OpenDAS / Paddle · Commits

Commit f0ef3442, authored Apr 26, 2023 by yuguo960516yuguo
Commit message: 2.3.2-dtk-22.10.1
Parent: ad08b8ce
Pipeline #227 failed with stages in 0 seconds
Changes: 274 · Pipelines: 1
Showing 20 changed files with 7023 additions and 0 deletions (+7023 -0)
  paddle/fluid/distributed/ps/service/heter_server.h                    +685   -0
  paddle/fluid/distributed/ps/service/ps_client.cc                       +93   -0
  paddle/fluid/distributed/ps/service/ps_client.h                       +376   -0
  paddle/fluid/distributed/ps/service/ps_local_client.cc                +328   -0
  paddle/fluid/distributed/ps/service/ps_local_client.h                 +237   -0
  paddle/fluid/distributed/ps/service/ps_local_server.h                  +44   -0
  paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc    +511   -0
  paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h     +215   -0
  paddle/fluid/distributed/ps/service/ps_service/service.cc             +137   -0
  paddle/fluid/distributed/ps/service/ps_service/service.h               +79   -0
  paddle/fluid/distributed/ps/service/sendrecv.proto                    +157   -0
  paddle/fluid/distributed/ps/service/server.cc                         +111   -0
  paddle/fluid/distributed/ps/service/server.h                          +212   -0
  paddle/fluid/distributed/ps/table/CMakeLists.txt                      +112   -0
  paddle/fluid/distributed/ps/table/accessor.h                          +180   -0
  paddle/fluid/distributed/ps/table/barrier_table.cc                     +78   -0
  paddle/fluid/distributed/ps/table/common_graph_table.cc              +2225   -0
  paddle/fluid/distributed/ps/table/common_graph_table.h                +789   -0
  paddle/fluid/distributed/ps/table/common_table.h                      +108   -0
  paddle/fluid/distributed/ps/table/ctr_accessor.cc                     +346   -0
Too many changes to show. To preserve performance only 274 of 274+ files are displayed.
paddle/fluid/distributed/ps/service/heter_server.h (new file, mode 100644)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle

DECLARE_double(eager_delete_tensor_gb);

namespace paddle {
namespace distributed {

DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);

using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;

using serviceHandler =
    std::function<int32_t(const PsRequestMessage& request,
                          PsResponseMessage& response,  // NOLINT
                          brpc::Controller* cntl)>;
using HeterServiceHandler =
    std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
using HeterRpcCallbackFunc = std::function<void(void*)>;

class ServiceHandlerBase {
 public:
  ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}

  virtual ~ServiceHandlerBase() {}

  void SetScope(const framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }

  virtual int Handle(const MultiVarMsg* request,
                     MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_;
  const framework::Scope* scope_;
};

using SharedMiniScope =
    std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
using SharedMicroScope = std::shared_ptr<std::unordered_map<
    int, std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
using SharedTaskQueue = std::shared_ptr<std::unordered_map<
    int, std::shared_ptr<::paddle::framework::BlockingQueue<
             std::pair<std::string, int>>>>>;

class ValueInSwitch {
 public:
  ValueInSwitch() {}
  ~ValueInSwitch() {}
  char* data() { return _data.data(); }
  size_t size() { return _data.size(); }
  void resize(size_t size) { _data.resize(size); }
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;
};

class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }

  virtual ~SendAndRecvVariableHandler() {}

  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }

  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    for (auto& scope_pair : (*micro_scopes_)) {
      // auto mini_idx = scope_pair.first;
      auto& micro_scopes = scope_pair.second;
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }

  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }

  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }

  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }

  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);

  int QueryInSwitchWithShard(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  int QueryInSwitchWithScope(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);

  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }

  int Handle(const MultiVarMsg* request,
             MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event(
        "SendAndRecvVariableHandler->Handle",
        platform::TracerEventType::Communication,
        1);
    FLAGS_eager_delete_tensor_gb = -1;

    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);

    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();

    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);

    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<phi::DenseTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;

    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();
      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
          1,
          platform::errors::InvalidArgument(
              "minibatch index should in current trainer"));
    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }

    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));

    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                                response_var_names,
                                                empty_var_names,
                                                *dev_ctx_,
                                                &local_scope,
                                                response,
                                                &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};

  int num_microbatch_;
  int num_minibatch_;
  std::mutex scope_mutex_;

  bool is_first_stage_ = false;
  bool is_last_stage_ = false;

  SharedTaskQueue task_queue_;
};

class HeterService : public PsService {
 public:
  HeterService() {
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  virtual ~HeterService() {}

  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);

    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }

  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller,
      const MultiVarMsg* request,
      MultiVarMsg* response,
      ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr,
        handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
    // We don't want to call done->Run() here, release the guard.
    // done_guard.release();
  }

  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request,
                              MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.QueryInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithScope failed!";
    }
    // response->set_message_name(message_name);
  }

  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // proxy: define a new OnHeterRpcDone object (or reset one inside the
    // OnHeterRpcDone class)
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(),
            true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());

    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    // brpc::Controller std_cntl;
    // std_cntl.request_attachment().append(cntl->request_attachment().movable());
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    fut.wait();
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }

  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request,
               PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.SaveInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // if (itr == handler_map_.end()) {
    //   LOG(ERROR) << "can not find func handler";
    // }
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithScope failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }

  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request,
                    PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();

    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }

  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }

  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }

  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }

  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }

  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }

  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }

  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};

class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_ == true) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }

  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }

  bool IsExit() { return service_.IsExit(); }

  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);

  void StartHeterService(bool need_encrypt = false);

  void StartHeterInterService(bool need_encrypt = false);

  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }

  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }

  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }

  void SetFanin(const int& fan_in);

  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }

  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }

  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }

  int GetThreadNum() { return request_handler_->GetThreadNum(); }

  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }

  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }

  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;
  static std::mutex mtx_;
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  bool stoped_ = true;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;

  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;

  int ready_;
};

}  // end namespace distributed
}  // end namespace paddle
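
A minimal usage sketch of the HeterServer API declared above, assuming the server runs in its own thread; it is not part of this commit, and the endpoint string and the "forward_entry" message name are hypothetical placeholders.

// Illustrative only (not in this commit): wiring up the HeterServer singleton
// with a SendAndRecvVariableHandler and one registered message handler.
#include "paddle/fluid/distributed/ps/service/heter_server.h"

void StartDemoHeterServer() {
  auto server = paddle::distributed::HeterServer::GetInstance();
  auto handler =
      std::make_shared<paddle::distributed::SendAndRecvVariableHandler>();
  server->SetEndPoint("127.0.0.1:8500");  // placeholder endpoint
  server->SetServiceHandler(handler);     // handler owns scopes / task queue
  server->SetLocalScope();
  server->RegisterServiceHandler(
      "forward_entry",  // hypothetical message name
      [handler](const paddle::distributed::MultiVarMsg* req,
                paddle::distributed::MultiVarMsg* res,
                brpc::Controller* cntl) -> int32_t {
        return handler->Handle(req, res, cntl);  // dispatch into the handler
      });
  // In real use this runs on a dedicated thread; callers then block on
  // server->WaitServerReady() before sending requests.
  server->StartHeterService();
}
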
paddle/fluid/distributed/ps/service/ps_client.cc (new file, mode 100644)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);

int32_t PSClient::Configure(  // called in FleetWrapper::InitWorker
    const PSParameter& config,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
    PSEnvironment& env,
    size_t client_id) {
  _env = &env;
  _config = config;
  _dense_pull_regions = regions;
  _client_id = client_id;
  _config.mutable_worker_param()
      ->mutable_downpour_worker_param()
      ->mutable_downpour_table_param()
      ->CopyFrom(_config.server_param()
                     .downpour_server_param()
                     .downpour_table_param());

  const auto& work_param = _config.worker_param().downpour_worker_param();

  for (int i = 0; i < work_param.downpour_table_param_size(); ++i) {
    auto* accessor = CREATE_PSCORE_CLASS(
        ValueAccessor,
        work_param.downpour_table_param(i).accessor().accessor_class());
    accessor->Configure(work_param.downpour_table_param(i).accessor());
    accessor->Initialize();
    _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
        accessor);
  }
  return Initialize();
}

PSClient* PSClientFactory::Create(const PSParameter& ps_config) {
  const auto& config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return NULL;
  }

  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return NULL;
  }

  if (!config.downpour_server_param().service_param().has_client_class()) {
    LOG(ERROR) << "miss client_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return NULL;
  }

  const auto& service_param = config.downpour_server_param().service_param();
  PSClient* client =
      CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
  if (client == NULL) {
    LOG(ERROR) << "client is not registered, server_name:"
               << service_param.client_class();
    return NULL;
  }

  TableManager::Instance().Initialize();
  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
  return client;
}
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/ps_client.h (new file, mode 100644)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {

using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;

typedef std::function<void(void*)> PSClientCallBack;

class PSClientClosure : public google::protobuf::Closure {
 public:
  explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
  virtual ~PSClientClosure() {}
  virtual void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }

  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {  // NOLINT
    _promises.push_back(promise);
  }

  void add_timer(std::shared_ptr<CostTimer>& timer) {  // NOLINT
    _timers.push_back(timer);
  }

 protected:
  PSClientCallBack _callback;
  std::vector<std::shared_ptr<CostTimer>> _timers;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

class PSClient {
 public:
  PSClient() {}
  virtual ~PSClient() {}
  PSClient(PSClient&&) = delete;
  PSClient(const PSClient&) = delete;

  virtual int32_t Configure(
      const PSParameter& config,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      PSEnvironment& _env,  // NOLINT
      size_t client_id) final;

  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry) = 0;

  // trigger table data eviction
  virtual std::future<int32_t> Shrink(uint32_t table_id,
                                      const std::string threshold) = 0;

  // load data for all tables
  virtual std::future<int32_t> Load(const std::string& epoch,
                                    const std::string& mode) = 0;
  // load data for the specified table
  virtual std::future<int32_t> Load(uint32_t table_id,
                                    const std::string& epoch,
                                    const std::string& mode) = 0;

  // save all tables; depending on mode, value_accessor may apply different
  // save conditions
  virtual std::future<int32_t> Save(const std::string& epoch,
                                    const std::string& mode) = 0;
  // save the specified table; depending on mode, value_accessor may apply
  // different save conditions
  virtual std::future<int32_t> Save(uint32_t table_id,
                                    const std::string& epoch,
                                    const std::string& mode) = 0;

  // clear table data
  virtual std::future<int32_t> Clear() = 0;
  virtual std::future<int32_t> Clear(uint32_t table_id) = 0;

  // pull the dense part of the parameters and fill it block by block into the
  // local network parameters
  // start and num are used to pull only part of the parameters
  // the keys and values buffers must not be reused before the future finishes
  // the client unpacks values by block and hands them to multiple senders
  // each sender gathers requests for the same block and accumulates several
  // fill buffers
  // the server extracts and returns the configured dimension of the block
  // the returned data is unpacked and filled into the accumulated buffers
  virtual std::future<int32_t> PullDense(Region* regions,
                                         size_t region_num,
                                         size_t table_id) = 0;

  // reserved
  // firstly push dense param for parameter server
  // this is necessary because dense weight initialized in trainer on cold
  // start
  virtual std::future<int32_t> PushDenseParam(const Region* regions,
                                              size_t region_num,
                                              size_t table_id) = 0;

  virtual std::future<int32_t> PushDense(const Region* regions,
                                         size_t region_num,
                                         size_t table_id) = 0;

  // issue a pull request with keys and fill the results into values
  // keys and values both hold num entries; each value occupies select_size
  // bytes
  // the keys and values buffers must not be reused before the future finishes
  // keys requested by multiple threads are merged, gathered and scattered to
  // the servers
  // once the results return, the buffer is traversed and values are assigned
  // is_training distinguishes training from inference requests; the server
  // handles features and admission differently for each.
  virtual std::future<int32_t> PullSparse(float** select_values,
                                          size_t table_id,
                                          const uint64_t* keys,
                                          size_t num,
                                          bool is_training) = 0;

  virtual std::future<int32_t> PullSparseParam(float** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num,
                                               bool is_training) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;

  // make sure all accumulated requests are actually sent
  virtual std::future<int32_t> Flush() = 0;

  // graceful server shutdown
  virtual std::future<int32_t> StopServer() = 0;

  // server profiler
  virtual std::future<int32_t> StartProfiler() = 0;
  virtual std::future<int32_t> StopProfiler() = 0;

  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) = 0;

  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) = 0;

  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) = 0;

  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) = 0;

  virtual void FinalizeWorker() = 0;

  // client to client message sending
  virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                    int to_client_id,
                                                    const std::string& msg) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  // client2client message handling, std::function<int32_t (int, int, const
  // std::string&) -> ret (msg_type, from_client_id, msg)
  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
  virtual int RegisteClient2ClientMsgHandler(int msg_type,
                                             MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }

  virtual int HandleClient2ClientMsg(int msg_type,
                                     int from_client_id,
                                     const std::string& msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
      return -1;
    }
    return itr->second(msg_type, from_client_id, msg);
  }

  virtual ValueAccessor* GetTableAccessor(size_t table_id) {
    auto itr = _table_accessors.find(table_id);
    if (itr == _table_accessors.end()) {
      return NULL;
    }
    return itr->second.get();
  }

  virtual size_t GetServerNums() = 0;

  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* done) = 0;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* done) = 0;

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) = 0;

  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) = 0;

  virtual std::future<int32_t> PushSparse(size_t table_id,
                                          const uint64_t* keys,
                                          const float** update_values,
                                          size_t num) = 0;

  // for save cache
  virtual std::future<int32_t> CacheShuffle(
      uint32_t table_id,
      const std::string& path,
      const std::string& mode,
      const std::string& cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string& path,
      const std::string& mode,
      const std::string& cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> SaveCache(uint32_t table_id,
                                         const std::string& path,
                                         const std::string& mode) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> GetCacheThreshold(
      uint32_t table_id, double& cache_threshold) {  // NOLINT
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> Revert() {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  virtual std::future<int32_t> CheckSavePrePatchDone() {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

 protected:
  virtual int32_t Initialize() = 0;
  PSParameter _config;
  std::map<uint64_t, std::vector<paddle::distributed::Region>>
      _dense_pull_regions;
  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
  std::unordered_map<int32_t, MsgHandlerFunc>
      _msg_handler_map;  // handles client2client messages

 public:
  size_t _client_id;
  PSEnvironment* _env;
};

template <class T>
class AsyncRequestTask {
 public:
  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
  AsyncRequestTask(T& data, size_t table_id, std::shared_ptr<CostTimer>& timer)
      : _table_id(table_id),
        _timer(timer),
        _promise(std::make_shared<std::promise<int32_t>>()) {
    _data = std::move(data);
  }

  AsyncRequestTask(AsyncRequestTask& data)  // NOLINT
      : _table_id(data.table_id()),
        _timer(data.timer()),
        _promise(data.promise()) {
    _data = std::move(data.data());
  }

  ~AsyncRequestTask() {}

  inline T& data() { return _data; }
  inline size_t table_id() { return _table_id; }
  inline std::shared_ptr<CostTimer>& timer() { return _timer; }
  inline std::future<int32_t> get_future() { return _promise->get_future(); }
  inline std::shared_ptr<std::promise<int32_t>>& promise() { return _promise; }

 private:
  T _data;
  size_t _table_id;
  std::shared_ptr<CostTimer> _timer;
  std::shared_ptr<std::promise<int32_t>> _promise;
};

REGISTER_PSCORE_REGISTERER(PSClient);

class PSClientFactory {
 public:
  static PSClient* Create(const PSParameter& config);
};
}  // namespace distributed
}  // namespace paddle
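
A minimal sketch (not part of this commit) of how a concrete client can be obtained through the PSClientFactory declared above; the factory dispatches on the client_class field of the PSParameter proto, matching one of the REGISTER_PSCORE_CLASS names in ps_client.cc. The empty regions map and client_id 0 are placeholder assumptions.

// Illustrative only: create a PSClient via the factory and configure it.
#include "paddle/fluid/distributed/ps/service/ps_client.h"

paddle::distributed::PSClient* CreateConfiguredClient(
    const paddle::distributed::PSParameter& param,
    paddle::distributed::PSEnvironment& env) {
  // client_class in param selects e.g. "PsLocalClient" or "BrpcPsClient".
  auto* client = paddle::distributed::PSClientFactory::Create(param);
  if (client == nullptr) {
    return nullptr;
  }
  std::map<uint64_t, std::vector<paddle::distributed::Region>> regions;
  client->Configure(param, regions, env, /*client_id=*/0);
  return client;
}
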
paddle/fluid/distributed/ps/service/ps_local_client.cc (new file, mode 100644)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace paddle {
namespace distributed {
int32_t PsLocalClient::Initialize() {
  const auto& downpour_param = _config.server_param().downpour_server_param();
  TableManager::Instance().Initialize();
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}

::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  // TODO
  return done();
}

::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  for (auto& it : _table_map) {
    Load(it.first, epoch, mode);
  }
  return done();
}

::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
  return done();
}

::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  for (auto& it : _table_map) {
    Save(it.first, epoch, mode);
  }
  return done();
}

::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
  return done();
}

::std::future<int32_t> PsLocalClient::Clear() {
  // TODO
  return done();
}

::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
  // TODO
  return done();
}

::std::future<int32_t> PsLocalClient::Flush() {
  // no need
  return done();
}

::std::future<int32_t> PsLocalClient::StopServer() {
  // no need
  return done();
}

::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> region_buffer;
  region_buffer.resize(num_per_shard);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  table_ptr->Pull(table_context);
  // table_ptr->PullDense(region_buffer.data(), region_buffer.size());

  size_t region_idx = 0;
  size_t region_data_idx = 0;
  size_t shard_data_size = num_per_shard;
  size_t shard_buffer_remain = shard_data_size * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain,
      region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));
  size_t index = 0;
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    if (region.size - region_data_idx >= shard_buffer_remain) {
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else if (region.size - region_data_idx == 0) {
      ++region_idx;
      region_data_idx = 0;
    } else {
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             region.size - region_data_idx);
      shard_buffer_remain -= (region.size - region_data_idx);
      index += (region.size - region_data_idx);
      ++region_idx;
      region_data_idx = 0;
    }
  }

  return done();
}

::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
                                                     size_t region_num,
                                                     size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  std::vector<float> region_buffer;
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1), 0);
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.push_context.is_param = true;
  table_context.num = region_buffer.size();

  table_ptr->Push(table_context);
  // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());

  return done();
}

::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";

  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);

  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

  delete closure;
  return done();
}

::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  std::vector<float> region_buffer;
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
        offset + data_num,
        data_size,
        platform::errors::PreconditionNotMet(
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

  return done();
}

//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
//                                                 size_t table_id,
//                                                 const uint64_t* keys,
//                                                 size_t num) {
//  // FIXME
//  // auto timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
//  // auto local_timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//  // split the keys into per-shard requests and record the original value
//  // pointers
//  auto* accessor = GetTableAccessor(table_id);
//  auto* table_ptr = GetTable(table_id);
//  size_t value_size = accessor->select_size();
//
//  // table_ptr->PullSparse(keys, num);
//  std::vector<float> res_data;
//  res_data.resize(num * value_size / sizeof(float));
//  table_ptr->PullSparse(res_data.data(), keys, num);
//  // memcpy(select_values[0], res_data->data(), res_data->size() *
//  // sizeof(float));
//  size_t offset = 0;
//  for (int i = 0; i < num; ++i) {
//    memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
//    offset += value_size;
//  }
//
//  // return fut;
//  return done();
//}

::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
                                                    size_t table_id,
                                                    const uint64_t* keys,
                                                    size_t num) {
  // FIXME
  // auto timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
  // split the keys into per-shard requests and record the original value
  // pointers
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;

  // table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);

  return done();
}

::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
    size_t table_id,
    const uint64_t* keys,
    const float** update_values,
    size_t num,
    void* callback) {
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
  delete closure;
  return done();
}

::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
                                                 const uint64_t* keys,
                                                 const float** update_values,
                                                 size_t num) {
  auto* table_ptr = GetTable(table_id);

  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
  return done();
}
}  // namespace distributed
}  // namespace paddle
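
PsLocalClient executes every request synchronously and returns an already-fulfilled future via its private done() helper, so callers can still wait on the result exactly as they would with the asynchronous brpc client. A minimal illustration (not part of this commit), assuming an already configured client:

// Illustrative only: the caller-side pattern is the same for local and
// remote clients, even though the local client resolves immediately.
void PushSparseAndWait(paddle::distributed::PSClient* client,
                       size_t table_id,
                       const uint64_t* keys,
                       const float** grads,
                       size_t num) {
  std::future<int32_t> status = client->PushSparse(table_id, keys, grads, num);
  status.wait();               // returns at once for PsLocalClient
  int32_t ret = status.get();  // 0 on success
  (void)ret;
}
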
paddle/fluid/distributed/ps/service/ps_local_client.h (new file, mode 100644)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace paddle {
namespace distributed {

class Table;

class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }

  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
    return 0;
  }

  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold) override;
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;

  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;

  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;

  virtual ::std::future<int32_t> StopServer() override;

  virtual void FinalizeWorker() override {}

  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
                                           size_t table_id);

  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num,
                                           size_t table_id);

  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);

  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys,
                                            size_t num,
                                            bool is_training) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);

  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);

  virtual ::std::future<int32_t> Flush();

  // server profiler
  virtual std::future<int32_t> StartProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  };

  virtual std::future<int32_t> StopProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
    return 0;
  }

  virtual ::std::future<int32_t> SendClient2ClientMsg(
      int msg_type, int to_client_id, const std::string& msg) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual size_t GetServerNums() { return 1; }

  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;

  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;

  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

 private:
  virtual int32_t Initialize() override;

  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }

  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;

  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/ps_local_server.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace paddle {
namespace distributed {

class PsLocalServer : public PSServer {
 public:
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  virtual uint64_t Start() { return 0; }
  virtual uint64_t Start(const std::string& ip, uint32_t port) { return 0; }
  virtual int32_t Stop() { return 0; }
  virtual int32_t Configure(
      const PSParameter& config, PSEnvironment& env, size_t server_rank,
      const std::vector<framework::ProgramDesc>& server_sub_program = {}) {
    return 0;
  }

 private:
  virtual int32_t Initialize() { return 0; }
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace distributed {

std::vector<std::string> GraphPyService::split(std::string& str,
                                               const char pattern) {
  std::vector<std::string> res;
  std::stringstream input(str);
  std::string temp;
  while (std::getline(input, temp, pattern)) {
    res.push_back(temp);
  }
  return res;
}

void GraphPyService::add_table_feat_conf(std::string table_name,
                                         std::string feat_name,
                                         std::string feat_dtype,
                                         int feat_shape) {
  if (feature_to_id.find(table_name) != feature_to_id.end()) {
    int idx = feature_to_id[table_name];
    VLOG(0) << "for table name" << table_name << " idx = " << idx;
    if (table_feat_mapping[idx].find(feat_name) ==
        table_feat_mapping[idx].end()) {
      VLOG(0) << "for table name not found,make a new one";
      int res = (int)table_feat_mapping[idx].size();
      table_feat_mapping[idx][feat_name] = res;
      VLOG(0) << "seq id = " << table_feat_mapping[idx][feat_name];
    }
    int feat_idx = table_feat_mapping[idx][feat_name];
    VLOG(0) << "table_name " << table_name << " mapping id " << idx;
    VLOG(0) << " feat name " << feat_name << " feat id" << feat_idx;
    if (static_cast<size_t>(feat_idx) < table_feat_conf_feat_name[idx].size()) {
      // overide
      table_feat_conf_feat_name[idx][feat_idx] = feat_name;
      table_feat_conf_feat_dtype[idx][feat_idx] = feat_dtype;
      table_feat_conf_feat_shape[idx][feat_idx] = feat_shape;
    } else {
      // new
      table_feat_conf_feat_name[idx].push_back(feat_name);
      table_feat_conf_feat_dtype[idx].push_back(feat_dtype);
      table_feat_conf_feat_shape[idx].push_back(feat_shape);
    }
  }
  VLOG(0) << "add conf over";
}

void add_graph_node(std::string name, std::vector<int64_t> node_ids,
                    std::vector<bool> weight_list) {}

void remove_graph_node(std::string name, std::vector<int64_t> node_ids) {}

void GraphPyService::set_up(std::string ips_str, int shard_num,
                            std::vector<std::string> node_types,
                            std::vector<std::string> edge_types) {
  set_shard_num(shard_num);
  set_num_node_types(node_types.size());
  /*
  int num_node_types;
  std::unordered_map<std::string, uint32_t> edge_idx, feature_idx;
  std::vector<std::unordered_map<std::string,uint32_t>> table_feat_mapping;
  std::vector<std::vector<std::string>> table_feat_conf_feat_name;
  std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
  std::vector<std::vector<int32_t>> table_feat_conf_feat_shape;
  */
  id_to_edge = edge_types;
  for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
    int res = (int)edge_to_id.size();
    edge_to_id[edge_types[table_id]] = res;
  }
  id_to_feature = node_types;
  for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
    int res = (int)feature_to_id.size();
    feature_to_id[node_types[table_id]] = res;
  }
  table_feat_mapping.resize(node_types.size());
  this->table_feat_conf_feat_name.resize(node_types.size());
  this->table_feat_conf_feat_dtype.resize(node_types.size());
  this->table_feat_conf_feat_shape.resize(node_types.size());

  std::istringstream stream(ips_str);
  std::string ip;
  server_size = 0;
  std::vector<std::string> ips_list = split(ips_str, ';');
  int index = 0;
  VLOG(0) << "start to build server";
  for (auto ips : ips_list) {
    auto ip_and_port = split(ips, ':');
    server_list.push_back(ip_and_port[0]);
    port_list.push_back(ip_and_port[1]);
    uint32_t port = stoul(ip_and_port[1]);
    auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
    host_sign_list.push_back(ph_host.SerializeToString());
    index++;
  }
  VLOG(0) << "build server done";
}

void GraphPyClient::start_client() {
  std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
  dense_regions.insert(
      std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
  auto regions = dense_regions[0];
  ::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
  paddle::distributed::PaddlePSEnvironment _ps_env;
  auto servers_ = host_sign_list.size();
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(&host_sign_list, servers_);
  worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
      (paddle::distributed::GraphBrpcClient*)
          paddle::distributed::PSClientFactory::Create(worker_proto));
  worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
  worker_ptr->set_shard_num(get_shard_num());
}

void GraphPyServer::start_server(bool block) {
  std::string ip = server_list[rank];
  uint32_t port = std::stoul(port_list[rank]);
  ::paddle::distributed::PSParameter server_proto = this->GetServerProto();

  auto _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(&this->host_sign_list,
                       this->host_sign_list.size());  // test
  pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
          paddle::distributed::PSServerFactory::Create(server_proto));
  VLOG(0) << "pserver-ptr created ";
  std::vector<framework::ProgramDesc> empty_vec;
  framework::ProgramDesc empty_prog;
  empty_vec.push_back(empty_prog);
  pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
  pserver_ptr->Start(ip, port);
  pserver_ptr->build_peer2peer_connection(rank);
  std::condition_variable* cv_ = pserver_ptr->export_cv();
  if (block) {
    std::mutex mutex_;
    std::unique_lock<std::mutex> lock(mutex_);
    cv_->wait(lock);
  }
}

::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
  // Generate server proto desc
  ::paddle::distributed::PSParameter server_fleet_desc;
  ::paddle::distributed::ServerParameter* server_proto =
      server_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}
  return server_fleet_desc;
}

::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
  ::paddle::distributed::PSParameter worker_fleet_desc;
  ::paddle::distributed::WorkerParameter* worker_proto =
      worker_fleet_desc.mutable_worker_param();

  ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
      worker_proto->mutable_downpour_worker_param();

  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* worker_sparse_table_proto =
      downpour_worker_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(worker_sparse_table_proto);
  //}

  ::paddle::distributed::ServerParameter* server_proto =
      worker_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);

  // for (auto& tuple : this->table_id_map) {
  //   VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}
  return worker_fleet_desc;
}

void GraphPyClient::load_edge_file(std::string name, std::string filepath,
                                   bool reverse) {
  // 'e' means load edge
  std::string params = "e";
  if (reverse) {
    // 'e<' means load edges from $2 to $1
    params += "<" + name;
  } else {
    // 'e>' means load edges from $1 to $2
    params += ">" + name;
  }
  if (edge_to_id.find(name) != edge_to_id.end()) {
    auto status = get_ps_client()->Load(0, std::string(filepath), params);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   VLOG(0) << "loadding data with type " << name << " from " << filepath;
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status =
  //       get_ps_client()->Load(table_id, std::string(filepath), params);
  //   status.wait();
  // }
}

void GraphPyClient::clear_nodes(std::string name) {
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->clear_nodes(0, 0, idx);
    status.wait();
  } else if (feature_to_id.find(name) != feature_to_id.end()) {
    int idx = feature_to_id[name];
    auto status = get_ps_client()->clear_nodes(0, 1, idx);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = get_ps_client()->clear_nodes(table_id);
  //   status.wait();
  // }
}

void GraphPyClient::add_graph_node(std::string name,
                                   std::vector<int64_t>& node_ids,
                                   std::vector<bool>& weight_list) {
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status =
  //       get_ps_client()->add_graph_node(table_id, node_ids, weight_list);
  //   status.wait();
  // }
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status =
        get_ps_client()->add_graph_node(0, idx, node_ids, weight_list);
    status.wait();
  }
}

void GraphPyClient::remove_graph_node(std::string name,
                                      std::vector<int64_t>& node_ids) {
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->remove_graph_node(0, idx, node_ids);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
  //   status.wait();
  // }
}

void GraphPyClient::load_node_file(std::string name, std::string filepath) {
  // 'n' means load nodes and 'node_type' follows
  std::string params = "n" + name;
  if (feature_to_id.find(name) != feature_to_id.end()) {
    auto status = get_ps_client()->Load(0, std::string(filepath), params);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status =
  //       get_ps_client()->Load(table_id, std::string(filepath), params);
  //   status.wait();
  // }
}

std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
                                      std::vector<int64_t> node_ids,
                                      int sample_size, bool return_weight,
                                      bool return_edges) {
  std::vector<std::vector<int64_t>> v;
  std::vector<std::vector<float>> v1;
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->batch_sample_neighbors(
        0, idx, node_ids, sample_size, v, v1, return_weight);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = worker_ptr->batch_sample_neighbors(
  //       table_id, node_ids, sample_size, v, v1, return_weight);
  //   status.wait();
  // }

  // res.first[0]: neighbors (nodes)
  // res.first[1]: slice index
  // res.first[2]: src nodes
  // res.second: edges weight
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
  res.first.push_back({});
  res.first.push_back({});
  if (return_edges) res.first.push_back({});
  for (size_t i = 0; i < v.size(); i++) {
    for (size_t j = 0; j < v[i].size(); j++) {
      // res.first[0].push_back(v[i][j].first);
      res.first[0].push_back(v[i][j]);
      if (return_edges) res.first[2].push_back(node_ids[i]);
      if (return_weight) res.second.push_back(v1[i][j]);
    }
    if (i == v.size() - 1) break;
    if (i == 0) {
      res.first[1].push_back(v[i].size());
    } else {
      res.first[1].push_back(v[i].size() + res.first[1].back());
    }
  }
  return res;
}

std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
                                                        int server_index,
                                                        int sample_size) {
  std::vector<int64_t> v;
  if (feature_to_id.find(name) != feature_to_id.end()) {
    int idx = feature_to_id[name];
    auto status = get_ps_client()->random_sample_nodes(0, 1, idx, server_index,
                                                       sample_size, v);
    status.wait();
  } else if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->random_sample_nodes(0, 0, idx, server_index,
                                                       sample_size, v);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status =
  //       worker_ptr->random_sample_nodes(table_id, server_index, sample_size,
  //       v);
  //   status.wait();
  // }
  return v;
}

// (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
    std::string name, std::vector<int64_t> node_ids,
    std::vector<std::string> feature_names) {
  std::vector<std::vector<std::string>> v(
      feature_names.size(), std::vector<std::string>(node_ids.size()));
  if (feature_to_id.find(name) != feature_to_id.end()) {
    int idx = feature_to_id[name];
    auto status =
        get_ps_client()->get_node_feat(0, idx, node_ids, feature_names, v);
    status.wait();
  }
  // if (this->table_id_map.count(node_type)) {
  //   uint32_t table_id = this->table_id_map[node_type];
  //   auto status =
  //       worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
  //   status.wait();
  // }
  return v;
}

void GraphPyClient::set_node_feat(
    std::string name, std::vector<int64_t> node_ids,
    std::vector<std::string> feature_names,
    const std::vector<std::vector<std::string>> features) {
  if (feature_to_id.find(name) != feature_to_id.end()) {
    int idx = feature_to_id[name];
    auto status = get_ps_client()->set_node_feat(0, idx, node_ids,
                                                 feature_names, features);
    status.wait();
  }
  // if (this->table_id_map.count(node_type)) {
  //   uint32_t table_id = this->table_id_map[node_type];
  //   auto status =
  //       worker_ptr->set_node_feat(table_id, node_ids, feature_names,
  //       features);
  //   status.wait();
  // }
  return;
}

std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
                                                        int server_index,
                                                        int start, int size,
                                                        int step) {
  std::vector<FeatureNode> res;
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
  //       size, step, res);
  //   status.wait();
  // }
  if (feature_to_id.find(name) != feature_to_id.end()) {
    int idx = feature_to_id[name];
    auto status = get_ps_client()->pull_graph_list(0, 1, idx, server_index,
                                                   start, size, step, res);
    status.wait();
  } else if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->pull_graph_list(0, 0, idx, server_index,
                                                   start, size, step, res);
    status.wait();
  }
  return res;
}

void GraphPyClient::StopServer() {
  VLOG(0) << "going to stop server";
  std::unique_lock<std::mutex> lock(mutex_);
  if (stoped_) return;
  auto status = this->worker_ptr->StopServer();
  if (status.get() == 0) stoped_ = true;
}

void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace distributed {

class GraphPyService {
 protected:
  std::vector<std::string> server_list, port_list, host_sign_list;
  int server_size, shard_num;
  int num_node_types;
  std::unordered_map<std::string, int> edge_to_id, feature_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  std::vector<std::unordered_map<std::string, int>> table_feat_mapping;
  std::vector<std::vector<std::string>> table_feat_conf_feat_name;
  std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
  std::vector<std::vector<int>> table_feat_conf_feat_shape;

 public:
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  void GetDownpourSparseTableProto(
      ::paddle::distributed::TableParameter* sparse_table_proto) {
    sparse_table_proto->set_table_id(0);
    sparse_table_proto->set_table_class("GraphTable");
    sparse_table_proto->set_shard_num(shard_num);
    sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
    ::paddle::distributed::TableAccessorParameter* accessor_proto =
        sparse_table_proto->mutable_accessor();
    // ::paddle::distributed::CommonAccessorParameter* common_proto =
    //     sparse_table_proto->mutable_common();
    ::paddle::distributed::GraphParameter* graph_proto =
        sparse_table_proto->mutable_graph_parameter();
    // ::paddle::distributed::GraphFeature* graph_feature =
    //     graph_proto->mutable_graph_feature();
    graph_proto->set_task_pool_size(24);
    graph_proto->set_table_name("cpu_graph_table");
    graph_proto->set_use_cache(false);
    for (size_t i = 0; i < id_to_edge.size(); i++)
      graph_proto->add_edge_types(id_to_edge[i]);
    for (size_t i = 0; i < id_to_feature.size(); i++) {
      graph_proto->add_node_types(id_to_feature[i]);
      auto feat_node = id_to_feature[i];
      ::paddle::distributed::GraphFeature* g_f =
          graph_proto->add_graph_feature();
      for (size_t x = 0; x < table_feat_conf_feat_name[i].size(); x++) {
        g_f->add_name(table_feat_conf_feat_name[i][x]);
        g_f->add_dtype(table_feat_conf_feat_dtype[i][x]);
        g_f->add_shape(table_feat_conf_feat_shape[i][x]);
      }
    }
    // Set GraphTable Parameter
    // common_proto->set_table_name(table_name);
    // common_proto->set_name(table_type);
    // for (size_t i = 0; i < feat_name.size(); i++) {
    //   common_proto->add_params(feat_dtype[i]);
    //   common_proto->add_dims(feat_shape[i]);
    //   common_proto->add_attributes(feat_name[i]);
    // }
    // for (size_t i = 0; i < feat_name.size(); i++) {
    //   graph_feature->add_dtype(feat_dtype[i]);
    //   graph_feature->add_shape(feat_shape[i]);
    //   graph_feature->add_name(feat_name[i]);
    // }
    accessor_proto->set_accessor_class("CommMergeAccessor");
  }

  void set_server_size(int server_size) { this->server_size = server_size; }
  void set_num_node_types(int num_node_types) {
    this->num_node_types = num_node_types;
  }
  int get_server_size(int server_size) { return server_size; }
  std::vector<std::string> split(std::string& str, const char pattern);
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types);
  void add_table_feat_conf(std::string node_type, std::string feat_name,
                           std::string feat_dtype, int32_t feat_shape);
};

class GraphPyServer : public GraphPyService {
 public:
  GraphPyServer() {}
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types, int rank) {
    set_rank(rank);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  int GetRank() { return rank; }
  void set_rank(int rank) { this->rank = rank; }

  void start_server(bool block = true);
  ::paddle::distributed::PSParameter GetServerProto();
  std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
    return pserver_ptr;
  }

 protected:
  int rank;
  std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr;
  std::thread* server_thread;
};

class GraphPyClient : public GraphPyService {
 public:
  void set_up(std::string ips_str, int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types, int client_id) {
    set_client_id(client_id);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() {
    return worker_ptr;
  }
  void bind_local_server(int local_channel_index, GraphPyServer& server) {
    worker_ptr->set_local_channel(local_channel_index);
    worker_ptr->set_local_graph_service(
        (paddle::distributed::GraphBrpcService*)server.get_ps_server()
            ->get_service());
  }
  void StopServer();
  void FinalizeWorker();
  void load_edge_file(std::string name, std::string filepath, bool reverse);
  void load_node_file(std::string name, std::string filepath);
  void clear_nodes(std::string name);
  void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
                      std::vector<bool>& weight_list);
  void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
  int get_client_id() { return client_id; }
  void set_client_id(int client_id) { this->client_id = client_id; }
  void start_client();
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
  batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
                         int sample_size, bool return_weight,
                         bool return_edges);
  std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
                                           int sample_size);
  std::vector<std::vector<std::string>> get_node_feat(
      std::string name, std::vector<int64_t> node_ids,
      std::vector<std::string> feature_names);
  void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
                     std::vector<std::string> feature_names,
                     const std::vector<std::vector<std::string>> features);
  std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
                                           int start, int size, int step = 1);
  ::paddle::distributed::PSParameter GetWorkerProto();

 protected:
  mutable std::mutex mutex_;
  int client_id;
  std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
  std::thread* client_thread;
  bool stoped_ = false;
};
}  // namespace distributed
}  // namespace paddle
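
A minimal wiring sketch for the two classes above, based only on the public methods declared in this header: one GraphPyServer rank and one GraphPyClient in the same process. The endpoint string, table names, and edge file path are illustrative assumptions, not values taken from this commit.

// Sketch: single-rank GraphPyServer plus a GraphPyClient in one process.
#include <string>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"

void RunLocalGraphService() {
  std::string ips = "127.0.0.1:4211";                    // assumed endpoint
  std::vector<std::string> node_types = {"user"};        // assumed node table
  std::vector<std::string> edge_types = {"user2item"};   // assumed edge table

  paddle::distributed::GraphPyServer server;
  server.set_up(ips, /*shard_num=*/4, node_types, edge_types, /*rank=*/0);
  std::thread server_thread([&server] { server.start_server(/*block=*/true); });

  paddle::distributed::GraphPyClient client;
  client.set_up(ips, /*shard_num=*/4, node_types, edge_types, /*client_id=*/0);
  client.start_client();

  // Load an (assumed) edge file, then sample up to 10 neighbors per node id.
  client.load_edge_file("user2item", "/tmp/edges.txt", /*reverse=*/false);
  auto res = client.batch_sample_neighbors("user2item", {1, 2, 3},
                                           /*sample_size=*/10,
                                           /*return_weight=*/false,
                                           /*return_edges=*/false);

  client.StopServer();  // releases the condition variable start_server waits on
  server_thread.join();
}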
paddle/fluid/distributed/ps/service/ps_service/service.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using namespace std;  // NOLINT

namespace paddle {
namespace distributed {

paddle::distributed::PSParameter load_from_prototxt(
    const std::string& filename) {
  paddle::distributed::PSParameter param;
  int file_descriptor = open(filename.c_str(), O_RDONLY);
  if (file_descriptor == -1) {
    VLOG(3) << "FATAL: fail to parse " << filename;
    exit(-1);
  }
  google::protobuf::io::FileInputStream fileInput(file_descriptor);
  if (!google::protobuf::TextFormat::Parse(&fileInput, &param)) {
    VLOG(3) << "FATAL: fail to parse " << filename;
    exit(-1);
  }
  close(file_descriptor);
  return param;
}

void PSCore::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  auto it = flags.begin();
  flags.insert(it, "exe default");
  char* flags_ptr[flags.size()];
  for (size_t i = 0; i < flags.size(); ++i) {
    flags_ptr[i] = (char*)(flags[i].c_str());  // NOLINT
  }
  int params_cnt = flags.size();
  char** params_ptr = &(flags_ptr[0]);
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}

int PSCore::InitServer(
    const std::string& dist_desc,
    const std::vector<std::string>* host_sign_list, int node_num, int index,
    int trainers,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  _ps_env.SetTrainers(trainers);
  int ret = 0;
  _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
      paddle::distributed::PSServerFactory::Create(_ps_param));
  ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
  CHECK(ret == 0) << "failed to configure server";
  return ret;
}

int PSCore::InitWorker(
    const std::string& dist_desc,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
        regions,
    const std::vector<std::string>* host_sign_list, int node_num, int index) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  int ret = 0;
  VLOG(1) << "PSCore::InitWorker";
  auto* communicator = Communicator::GetInstance();
  ret = communicator->GetPsClient()->Configure(_ps_param, regions, _ps_env,
                                               index);
  communicator->Start();
  return ret;
}

std::vector<uint64_t> PSCore::GetClientInfo() {
  return _ps_env.GetClientInfo();
}

int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
                                          int pserver_connect_timeout_ms,
                                          int max_retry) {
  int ret = _worker_ptr->CreateClient2ClientConnection(
      pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  return ret;
}

uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
  return _server_ptr->Start(ip, port);
}

int PSCore::FinalizeWorker() {
  _worker_ptr->FinalizeWorker();
  return 0;
}

int PSCore::StopServer() {
  auto stop_status = _worker_ptr->StopServer();
  stop_status.wait();
  return 0;
}

paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/service/ps_service/service.h
0 → 100644
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
namespace paddle {
namespace distributed {

class PSClient;
class PSServer;
class PsRequestMessage;
class PsResponseMessage;
class PsService;

using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;

class PSCore {
 public:
  explicit PSCore() {}
  virtual ~PSCore() {}

  virtual int InitServer(
      const std::string& dist_desc,
      const std::vector<std::string>* host_sign_list, int node_num, int index,
      int trainers,
      const std::vector<framework::ProgramDesc>& server_sub_program = {});
  virtual int InitWorker(
      const std::string& dist_desc,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      const std::vector<std::string>* host_sign_list, int node_num, int index);
  virtual uint64_t RunServer(const std::string& ip, uint32_t port);
  virtual int StopServer();
  virtual int FinalizeWorker();
  virtual std::vector<uint64_t> GetClientInfo();
  virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
                                            int pserver_connect_timeout_ms,
                                            int max_retry);

  std::shared_ptr<paddle::distributed::PSServer> _server_ptr;  // pointer to server
  std::shared_ptr<paddle::distributed::PSClient> _worker_ptr;  // pointer to worker
  virtual paddle::distributed::PSParameter* GetParam();

 private:
  void InitGFlag(const std::string& gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
};
}  // namespace distributed
}  // namespace paddle
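
PSCore is a thin facade over the server and worker bring-up declared above. A rough server-side lifecycle sketch, assuming dist_desc holds a text-format PSParameter and host_list the serialized PSHost strings (both are caller-supplied assumptions, not defined in this file):

// Sketch: parse the config, register the hosts, start serving on this rank.
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"

uint64_t BringUpServer(const std::string& dist_desc,
                       const std::vector<std::string>& host_list) {
  paddle::distributed::PSCore core;
  core.InitServer(dist_desc, &host_list,
                  /*node_num=*/static_cast<int>(host_list.size()),
                  /*index=*/0, /*trainers=*/1);
  // The endpoint is assumed to match this rank's entry in host_list.
  return core.RunServer("127.0.0.1", 4444);
}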
paddle/fluid/distributed/ps/service/sendrecv.proto
0 → 100644
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;

enum PsCmdID {
  PS_PULL_DENSE_TABLE = 0;
  PS_PUSH_DENSE_TABLE = 1;
  PS_PULL_SPARSE_TABLE = 2;
  PS_PUSH_SPARSE_TABLE = 3;
  PS_SHRINK_TABLE = 4;
  PS_SAVE_ONE_TABLE = 5;
  PS_SAVE_ALL_TABLE = 6;
  PS_LOAD_ONE_TABLE = 7;
  PS_LOAD_ALL_TABLE = 8;
  PS_CLEAR_ONE_TABLE = 9;
  PS_CLEAR_ALL_TABLE = 10;
  PS_PUSH_DENSE_PARAM = 11;
  PS_STOP_SERVER = 12;
  PS_SAVE_ONE_CACHE_TABLE = 13;
  PS_GET_CACHE_THRESHOLD = 14;
  PS_CACHE_SHUFFLE = 15;
  PS_COPY_TABLE = 16;
  PS_COPY_TABLE_BY_FEASIGN = 17;
  PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
  PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
  PS_PRINT_TABLE_STAT = 20;
  PS_SAVE_ONE_TABLE_PREFIX = 21;
  PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
  PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
  PS_PULL_GEO_PARAM = 24;
  PS_BARRIER = 25;
  PS_PUSH_SPARSE_PARAM = 26;
  PS_START_PROFILER = 27;
  PS_STOP_PROFILER = 28;
  PS_PUSH_GLOBAL_STEP = 29;
  PS_PULL_GRAPH_LIST = 30;
  PS_GRAPH_SAMPLE_NEIGHBORS = 31;
  PS_GRAPH_SAMPLE_NODES = 32;
  PS_GRAPH_GET_NODE_FEAT = 33;
  PS_GRAPH_CLEAR = 34;
  PS_GRAPH_ADD_GRAPH_NODE = 35;
  PS_GRAPH_REMOVE_GRAPH_NODE = 36;
  PS_GRAPH_SET_NODE_FEAT = 37;
  PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
  PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
  PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
  PEER_ROLE_IS_WORKER = 41;
  PEER_ROLE_IS_SWITCH = 42;
  PS_SAVE_WITH_SCOPE = 43;
  PS_SAVE_WITH_SHARD = 44;
  PS_QUERY_WITH_SCOPE = 45;
  PS_QUERY_WITH_SHARD = 46;
  PS_REVERT = 47;
  PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
  // pserver2pserver cmd start from 100
  PS_S2S_MSG = 101;
  PUSH_FL_CLIENT_INFO_SYNC = 200;
  PUSH_FL_STRATEGY = 201;
}

message PsRequestMessage {
  required uint32 cmd_id = 1;
  optional uint32 table_id = 2;
  repeated bytes params = 3;
  optional int32 client_id = 4;
  optional bytes data = 5;
};

message PsResponseMessage {
  required int32 err_code = 1 [ default = 0 ];
  required string err_msg = 2 [ default = "" ];
  optional bytes data = 3;
};

message CoordinatorReqMessage {
  required uint32 cmd_id = 1;
  optional int32 client_id = 2;
  optional string str_params = 3;
};

message CoordinatorResMessage {
  required int32 err_code = 1 [ default = 0 ];
  required string err_msg = 2 [ default = "" ];
  optional string str_params = 3;
};

enum VarType {
  LOD_TENSOR = 0;
  SELECTED_ROWS = 1;
}

message VariableMessage {
  enum Type {
    // Pod Types
    BOOL = 0;
    INT16 = 1;
    INT32 = 2;
    INT64 = 3;
    FP16 = 4;
    FP32 = 5;
    FP64 = 6;
  }

  message LodData { repeated int64 lod_data = 1; }
  optional string varname = 1;
  // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
  optional VarType type = 2;
  // bool persistable is not needed for sending.
  // tensor info:
  optional Type data_type = 3;
  repeated int64 dims = 4;

  // lod details:
  optional int64 lod_level = 5;
  repeated LodData lod = 6;
  // selected_rows height, aka. original dim0
  optional int64 slr_height = 7;
  // tensor data
  optional bytes data = 8;
}

// for SendAndRecv RPC method
message MultiVariableMessage {
  // message flags
  required string message_name = 1;
  repeated string send_var_names = 2;
  repeated string recv_var_names = 3;
  repeated VariableMessage var_messages = 4;
  optional bytes data = 5;
  repeated int64 vars_len = 6;
  optional int32 group_id = 7;
};

service PsService {
  rpc service(PsRequestMessage) returns (PsResponseMessage);
  rpc FLService(CoordinatorReqMessage) returns (CoordinatorResMessage);
  rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
  rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
  rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
  rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage);
  rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage);
};
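
Each request on this service carries one of the PsCmdID values above in cmd_id. As a small illustration of how a call maps onto the wire format, using the standard C++ setters that protoc generates for proto2 messages (the table id and save path below are arbitrary examples):

// Sketch: build a PS_SAVE_ONE_TABLE request for table 0.
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"

paddle::distributed::PsRequestMessage MakeSaveRequest() {
  paddle::distributed::PsRequestMessage request;
  request.set_cmd_id(paddle::distributed::PS_SAVE_ONE_TABLE);
  request.set_table_id(0);
  request.add_params("/tmp/table_0");  // assumed save-path parameter
  request.set_client_id(0);
  return request;
}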
paddle/fluid/distributed/ps/service/server.cc
0 → 100644
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_local_server.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {

REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PSServer, PsLocalServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);

PSServer* PSServerFactory::Create(const PSParameter& ps_config) {
  const auto& config = ps_config.server_param();

  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return NULL;
  }

  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return NULL;
  }

  if (!config.downpour_server_param().service_param().has_server_class()) {
    LOG(ERROR) << "miss server_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return NULL;
  }

  const auto& service_param = config.downpour_server_param().service_param();
  PSServer* server =
      CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
  if (server == NULL) {
    LOG(ERROR) << "server is not registered, server_name:"
               << service_param.server_class();
    return NULL;
  }
  TableManager::Instance().Initialize();
  return server;
}

int32_t PSServer::Configure(
    const PSParameter& config, PSEnvironment& env, size_t server_rank,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  scope_.reset(new framework::Scope());
  _config = config.server_param();
  _rank = server_rank;
  _environment = &env;
  _shuffled_ins =
      paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();

  size_t shard_num = env.GetPsServers().size();
  const auto& downpour_param = _config.downpour_server_param();

  uint32_t barrier_table = UINT32_MAX;
  uint32_t global_step_table = UINT32_MAX;

  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());

    if (downpour_param.downpour_table_param(i).table_class() ==
        "BarrierTable") {
      barrier_table = downpour_param.downpour_table_param(i).table_id();
    }
    if (downpour_param.downpour_table_param(i).table_class() ==
        "GlobalStepTable") {
      global_step_table = downpour_param.downpour_table_param(i).table_id();
    }

    table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
    table->SetShard(_rank, shard_num);
    table->Initialize(downpour_param.downpour_table_param(i),
                      config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }

  if (barrier_table != UINT32_MAX) {
    _table_map[barrier_table]->SetTableMap(&_table_map);
  }
  if (global_step_table != UINT32_MAX) {
    _table_map[global_step_table]->SetTableMap(&_table_map);
  }
  return Initialize();
}
}  // namespace distributed
}  // namespace paddle
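
PSServerFactory::Create only checks that server_param.downpour_server_param.service_param.server_class is present and registered; tables and the environment are validated later in Configure. A minimal creation sketch filling in just that field, using the same mutable_* accessors GetServerProto uses earlier in this commit:

// Sketch: resolve the registered PsLocalServer class through the factory.
#include "paddle/fluid/distributed/ps/service/server.h"

paddle::distributed::PSServer* CreateLocalServer() {
  paddle::distributed::PSParameter ps_param;
  ps_param.mutable_server_param()
      ->mutable_downpour_server_param()
      ->mutable_service_param()
      ->set_server_class("PsLocalServer");
  return paddle::distributed::PSServerFactory::Create(ps_param);
}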
paddle/fluid/distributed/ps/service/server.h
0 → 100644
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace google {
namespace protobuf {
class RpcController;
}  // namespace protobuf
}  // namespace google

namespace paddle {
namespace distributed {
class PSEnvironment;
}  // namespace distributed
}  // namespace paddle

namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace distributed {

class Table;

using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;

class PSServer {
 public:
  PSServer() {}
  virtual ~PSServer() {}
  PSServer(PSServer&&) = delete;
  PSServer(const PSServer&) = delete;

  virtual int32_t Configure(
      const PSParameter& config, PSEnvironment& env, size_t server_rank,
      const std::vector<framework::ProgramDesc>& server_sub_program = {});

  virtual uint64_t Start(const std::string& ip, uint32_t port) = 0;
  virtual int32_t Stop() = 0;

  inline size_t Rank() const { return _rank; }

  inline PSEnvironment* Environment() { return _environment; }

  inline const ServerParameter* Config() const { return &_config; }

  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    return NULL;
  }

  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }

  // for cache
  virtual int32_t StartS2S() { return 0; }

  virtual ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string& msg) {
    LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;

  virtual int RegistePServer2PServerMsgHandler(int msg_type,
                                               MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
  virtual int HandlePServer2PServerMsg(int msg_type, int from_pserver_id,
                                       const std::string& msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      if (msg_type == 101) {
        return ReceiveFromPServer(msg_type, from_pserver_id, msg);
      } else {
        LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
        return -1;
      }
    }
    return itr->second(msg_type, from_pserver_id, msg);
  }
  virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
                                     const std::string& msg) {
    LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
    return -1;
  }

  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

 protected:
  virtual int32_t Initialize() = 0;

 protected:
  size_t _rank;
  ServerParameter _config;
  PSEnvironment* _environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;

 protected:
  std::shared_ptr<framework::Scope> scope_;
  platform::Place place_ = platform::CPUPlace();
};

REGISTER_PSCORE_REGISTERER(PSServer);

typedef std::function<void(void*)> PServerCallBack;

class PServerClosure : public google::protobuf::Closure {
 public:
  PServerClosure(PServerCallBack callback) : _callback(callback) {}
  virtual ~PServerClosure() {}
  virtual void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }
  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {
    _promises.push_back(promise);
  }

 protected:
  PServerCallBack _callback;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};

class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
  virtual ~PsBaseService() {}
  virtual size_t GetRank() { return _rank; }
  virtual int32_t Configure(PSServer* server) {
    _server = server;
    _rank = _server->Rank();
    _config = _server->Config();
    return 0;
  }
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) override = 0;

  virtual void set_response_code(PsResponseMessage& response, int err_code,
                                 const char* err_msg) {
    response.set_err_msg(err_msg);
    response.set_err_code(err_code);
    LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
  }

  virtual int32_t Initialize() = 0;
  PSServer* GetServer() { return _server; }

 protected:
  size_t _rank;
  PSServer* _server;
  const ServerParameter* _config;
};
REGISTER_PSCORE_REGISTERER(PsBaseService);

class PSServerFactory {
 public:
  static PSServer* Create(const PSParameter& config);
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/CMakeLists.txt
0 → 100644
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
set(graphDir graph)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(
  ${graphDir}/graph_edge.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS ${graphDir}/graph_edge.cc)
set_source_files_properties(
  ${graphDir}/graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
  WeightedSampler
  SRCS ${graphDir}/graph_weighted_sampler.cc
  DEPS graph_edge)
set_source_files_properties(
  ${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
  graph_node
  SRCS ${graphDir}/graph_node.cc
  DEPS WeightedSampler enforce)
set_source_files_properties(
  memory_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/")
include_directories(
  ${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmct/include)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")

set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
#set(EXTERN_DEP rocksdb)

cc_library(
  common_table
  SRCS ${TABLE_SRC}
  DEPS ${TABLE_DEPS}
       ${RPC_DEPS}
       graph_edge
       graph_node
       device_context
       string_helper
       simple_threadpool
       xxhash
       generator)

set_source_files_properties(
  tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
  tensor_table
  SRCS
  DEPS eigen3
       ps_framework_proto
       executor
       scope
       device_context
       tensor
       ${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS
                                                ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ctr_dymf_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
                                        ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
  table
  SRCS sparse_sgd_rule.cc
       ctr_accessor.cc
       ctr_double_accessor.cc
       sparse_accessor.cc
       ctr_dymf_accessor.cc
       tensor_accessor.cc
       memory_sparse_table.cc
       ssd_sparse_table.cc
       memory_sparse_geo_table.cc
       table.cc
  DEPS ${TABLE_DEPS}
       common_table
       tensor_table
       ps_framework_proto
       string_helper
       device_context
       gflags
       glog
       fs
       afs_wrapper
       rocksdb
       eigen3)

target_link_libraries(table -fopenmp)
paddle/fluid/distributed/ps/table/accessor.h
0 → 100644
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
namespace paddle {
namespace distributed {

struct Region {
  Region() : data(NULL), size(0) {}
  Region(char* data, size_t data_num) : data(data), size(data_num) {}
  Region(float* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
  Region(int16_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num << 1) {}
  Region(int32_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
  Region(int64_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num << 3) {}
  char* data;
  size_t size;
};

struct DataConverter {
  int param;
  std::string converter;
  std::string deconverter;
};

struct AccessorInfo {
  // dimension of value
  size_t dim;
  // size of each dimension of value
  size_t size;
  // dimension of pull value
  size_t select_dim;
  // total size across pull value dimensions
  size_t select_size;
  // dimension of push value
  size_t update_dim;
  // size of each dimension of push value
  size_t update_size;
  // total size of the dynamic-length mf part of value; effective for sparse
  size_t mf_size;
  // total dimension of value; effective for dense
  size_t fea_dim;
};

class ValueAccessor {
 public:
  ValueAccessor() {}
  virtual ~ValueAccessor() {}

  virtual int Configure(const TableAccessorParameter& parameter) {
    _config = parameter;
    // initialize the data_convert structs
    if (_config.table_accessor_save_param_size() != 0) {
      for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
        int param = _config.table_accessor_save_param(i).param();
        std::string converter =
            _config.table_accessor_save_param(i).converter();
        std::string deconverter =
            _config.table_accessor_save_param(i).deconverter();
        _data_coverter_map[param] = std::make_shared<DataConverter>();
        *(_data_coverter_map[param]) = {param, converter, deconverter};
      }
    }
    return 0;
  }
  virtual int Initialize() = 0;

  virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }

  virtual bool NeedExtendMF(float* value) { return false; }
  virtual bool HasMF(size_t size) { return false; }
  // converter for save
  virtual std::string GetConverter(int param) {
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->converter;
    }
  }
  // deconverter for load
  virtual std::string GetDeconverter(int param) {
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->deconverter;
    }
  }
  // decide whether this value should be shrunk
  virtual bool Shrink(float* value) = 0;

  // decide whether this value should be dumped in the save stage;
  // param identifies the save stage, e.g. downpour's xbox vs. batch_model
  virtual bool Save(float* value, int param) = 0;
  // update delta_score and unseen_days after save
  virtual void UpdateStatAfterSave(float* value, int param) {}
  // decide whether this value should be saved to ssd
  virtual bool SaveSSD(float* value) = 0;
  //
  virtual bool SaveCache(float* value,
                         int param,
                         double global_cache_threshold) = 0;
  // when keys do not exist, generate random values for them
  virtual int32_t Create(float** value, size_t num) = 0;
  virtual bool CreateValue(int type, const float* value) { return true; }
  // select from values into select_values
  virtual int32_t Select(float** select_values,
                         const float** values,
                         size_t num) = 0;
  // merge update_values together
  virtual int32_t Merge(float** update_values,
                        const float** other_update_values,
                        size_t num) = 0;
  // merge update_values together; it.next decides whether to move on to the
  // next key
  // virtual int32_t Merge(float** update_values, iterator it);
  // apply update_values to values
  virtual int32_t Update(float** values,
                         const float** update_values,
                         size_t num) = 0;
  // used to save model, will filter feature
  virtual std::string ParseToString(const float* value, int param) = 0;
  // parse value from string, used to load model
  virtual int32_t ParseFromString(const std::string& data, float* value) = 0;

  virtual FsDataConverter Converter(int param) {
    FsDataConverter data_convert;
    data_convert.converter = this->GetConverter(param);
    data_convert.deconverter = this->GetDeconverter(param);
    return data_convert;
  }

  virtual int SetWeight(float** values,
                        const float** update_values,
                        size_t num) {
    return 0;
  }

  virtual float GetField(float* value, const std::string& name) { return 0.0; }
#define DEFINE_GET_INDEX(class, field) \
  virtual int get_##field##_index() override { return class ::field##_index(); }

 protected:
  size_t _value_size;
  size_t _select_value_size;
  size_t _update_value_size;
  TableAccessorParameter _config;
  std::unordered_map<int, std::shared_ptr<struct DataConverter>>
      _data_coverter_map;
  AccessorInfo _accessor_info;
};
REGISTER_PSCORE_REGISTERER(ValueAccessor);
}  // namespace distributed
}  // namespace paddle
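The Region constructors above fold the element type into the byte count via shifts (<< 2 for 4-byte types, << 1 for int16_t, << 3 for int64_t). A minimal standalone sketch of that arithmetic, not part of the Paddle sources:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Standalone sketch mirroring Region's byte-size arithmetic:
// data_num << 2 for float/int32_t (4 bytes), << 1 for int16_t, << 3 for int64_t.
int main() {
  const size_t data_num = 10;
  assert((data_num << 2) == data_num * sizeof(float));    // 40 bytes
  assert((data_num << 1) == data_num * sizeof(int16_t));  // 20 bytes
  assert((data_num << 3) == data_num * sizeof(int64_t));  // 80 bytes
  return 0;
}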
paddle/fluid/distributed/ps/table/barrier_table.cc
0 → 100644
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_table.h"
namespace paddle {
namespace distributed {

int32_t BarrierTable::Initialize() {
  auto trainers = _config.common().trainer_num();
  trigger_.store(trainers);

  for (int x = 0; x < trainers; ++x) {
    trainer_all_.insert(x);
  }
  VLOG(1) << "BarrierTable init trigger: " << trigger_.load();
  return 0;
}

// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::Barrier(const uint32_t trainer_id,
                              const std::string barrier_type) {
  std::unique_lock<std::mutex> lock(mutex_);

  if (barrier_type == "2") {
    trigger_.fetch_sub(1, std::memory_order::memory_order_relaxed);
    VLOG(1) << "trigger sub to : " << trigger_.load();
  } else {
    trainer_ids_.insert(trainer_id);
    VLOG(1) << "barrier type: " << barrier_type
            << " add trainer id: " << trainer_id;
  }

  if (static_cast<int>(trainer_ids_.size()) < trigger_.load()) {
    std::vector<uint32_t> diffs(trainer_all_.size());
    auto iter = std::set_difference(trainer_all_.begin(),
                                    trainer_all_.end(),
                                    trainer_ids_.begin(),
                                    trainer_ids_.end(),
                                    diffs.begin());
    diffs.resize(iter - diffs.begin());

    auto diff = to_string<uint32_t>(diffs);
    VLOG(1) << "still need trainers: " << diff;
    trainer_wait_.wait(lock, [&] { return trainer_ids_.size() == 0; });
  } else {
    VLOG(1) << "barrier table optimize begin";
    for (auto& x : *table_map_) {
      auto table = x.second;
      table->Pour();
    }
    VLOG(1) << "barrier table optimize done";

    trainer_ids_.clear();
    trainer_wait_.notify_all();
  }
  return 0;
}

int32_t BarrierTable::SetTableMap(
    std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
  table_map_ = table_map;
  return 0;
}

}  // namespace distributed
}  // namespace paddle
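A hedged standalone analogue (not Paddle code) of the wait/notify pattern BarrierTable::Barrier implements above: arrivals are recorded in a set, and the last trainer to arrive clears the set (standing in for the Pour() step) and wakes the waiters.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <set>
#include <thread>

// Toy barrier sketch under the assumption of a fixed trainer count.
class ToyBarrier {
 public:
  explicit ToyBarrier(int trainers) : trainers_(trainers) {}
  void arrive(int trainer_id) {
    std::unique_lock<std::mutex> lock(mutex_);
    ids_.insert(trainer_id);
    if ((int)ids_.size() < trainers_) {
      cv_.wait(lock, [&] { return ids_.empty(); });  // wait for the last arrival
    } else {
      ids_.clear();      // stands in for the Pour() step on every table
      cv_.notify_all();  // release everyone waiting at the barrier
    }
  }

 private:
  int trainers_;
  std::set<int> ids_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  ToyBarrier barrier(2);
  std::thread t0([&] { barrier.arrive(0); std::cout << "trainer 0 passed\n"; });
  std::thread t1([&] { barrier.arrive(1); std::cout << "trainer 1 passed\n"; });
  t0.join();
  t1.join();
}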
paddle/fluid/distributed/ps/table/common_graph_table.cc
0 → 100644
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(graph_load_in_parallel);

namespace paddle {
namespace distributed {

#ifdef PADDLE_WITH_HETERPS
int32_t GraphTable::Load_to_ssd(const std::string& path,
                                const std::string& param) {
  bool load_edge = (param[0] == 'e');
  bool load_node = (param[0] == 'n');
  if (load_edge) {
    bool reverse_edge = (param[1] == '<');
    std::string edge_type = param.substr(2);
    return this->load_edges_to_ssd(path, reverse_edge, edge_type);
  }
  if (load_node) {
    std::string node_type = param.substr(1);
    return this->load_nodes(path, node_type);
  }
  return 0;
}
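A hedged standalone sketch of the param string convention that Load_to_ssd above (and GraphTable::Load further down) parse: the first character selects edges ('e') or nodes ('n'); for edges the second character is compared against '<' to request reversed edges, and the remainder names the type. The '>' sentinel in the example is only an illustration of the non-reversed case, and the type names are hypothetical.

#include <iostream>
#include <string>

// Illustrative only: mirrors how the "param" string above is interpreted.
void describe_load_param(const std::string& param) {
  if (param.empty()) return;
  if (param[0] == 'e') {
    bool reverse_edge = param.size() > 1 && param[1] == '<';
    std::cout << "load edges, reverse=" << reverse_edge
              << ", edge_type=" << param.substr(2) << "\n";
  } else if (param[0] == 'n') {
    std::cout << "load nodes, node_type=" << param.substr(1) << "\n";
  }
}

int main() {
  describe_load_param("e>user2item");  // load edges, reverse=0, edge_type=user2item
  describe_load_param("nuser");        // load nodes, node_type=user
}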
paddle
::
framework
::
GpuPsCommGraphFea
GraphTable
::
make_gpu_ps_graph_fea
(
std
::
vector
<
uint64_t
>
&
node_ids
,
int
slot_num
)
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
bags
(
task_pool_size_
);
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
auto
predsize
=
node_ids
.
size
()
/
task_pool_size_
;
bags
[
i
].
reserve
(
predsize
*
1.2
);
}
for
(
auto
x
:
node_ids
)
{
int
location
=
x
%
shard_num
%
task_pool_size_
;
bags
[
location
].
push_back
(
x
);
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
std
::
vector
<
uint64_t
>
feature_array
[
task_pool_size_
];
std
::
vector
<
uint8_t
>
slot_id_array
[
task_pool_size_
];
std
::
vector
<
uint64_t
>
node_id_array
[
task_pool_size_
];
std
::
vector
<
paddle
::
framework
::
GpuPsFeaInfo
>
node_fea_info_array
[
task_pool_size_
];
for
(
size_t
i
=
0
;
i
<
bags
.
size
();
i
++
)
{
if
(
bags
[
i
].
size
()
>
0
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
uint64_t
node_id
;
paddle
::
framework
::
GpuPsFeaInfo
x
;
std
::
vector
<
uint64_t
>
feature_ids
;
for
(
size_t
j
=
0
;
j
<
bags
[
i
].
size
();
j
++
)
{
// TODO use FEATURE_TABLE instead
Node
*
v
=
find_node
(
1
,
bags
[
i
][
j
]);
node_id
=
bags
[
i
][
j
];
if
(
v
==
NULL
)
{
x
.
feature_size
=
0
;
x
.
feature_offset
=
0
;
node_fea_info_array
[
i
].
push_back
(
x
);
}
else
{
// x <- v
x
.
feature_offset
=
feature_array
[
i
].
size
();
int
total_feature_size
=
0
;
for
(
int
k
=
0
;
k
<
slot_num
;
++
k
)
{
v
->
get_feature_ids
(
k
,
&
feature_ids
);
total_feature_size
+=
feature_ids
.
size
();
if
(
!
feature_ids
.
empty
())
{
feature_array
[
i
].
insert
(
feature_array
[
i
].
end
(),
feature_ids
.
begin
(),
feature_ids
.
end
());
slot_id_array
[
i
].
insert
(
slot_id_array
[
i
].
end
(),
feature_ids
.
size
(),
k
);
}
}
x
.
feature_size
=
total_feature_size
;
node_fea_info_array
[
i
].
push_back
(
x
);
}
node_id_array
[
i
].
push_back
(
node_id
);
}
return
0
;
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
paddle
::
framework
::
GpuPsCommGraphFea
res
;
uint64_t
tot_len
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
tot_len
+=
feature_array
[
i
].
size
();
}
VLOG
(
0
)
<<
"Loaded feature table on cpu, feature_list_size["
<<
tot_len
<<
"] node_ids_size["
<<
node_ids
.
size
()
<<
"]"
;
res
.
init_on_cpu
(
tot_len
,
(
unsigned
int
)
node_ids
.
size
(),
slot_num
);
unsigned
int
offset
=
0
,
ind
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
node_id_array
[
i
].
size
();
j
++
)
{
res
.
node_list
[
ind
]
=
node_id_array
[
i
][
j
];
res
.
fea_info_list
[
ind
]
=
node_fea_info_array
[
i
][
j
];
res
.
fea_info_list
[
ind
++
].
feature_offset
+=
offset
;
}
for
(
size_t
j
=
0
;
j
<
feature_array
[
i
].
size
();
j
++
)
{
res
.
feature_list
[
offset
+
j
]
=
feature_array
[
i
][
j
];
res
.
slot_id_list
[
offset
+
j
]
=
slot_id_array
[
i
][
j
];
}
offset
+=
feature_array
[
i
].
size
();
}
return
res
;
}
paddle
::
framework
::
GpuPsCommGraph
GraphTable
::
make_gpu_ps_graph
(
int
idx
,
std
::
vector
<
uint64_t
>
ids
)
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
bags
(
task_pool_size_
);
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
auto
predsize
=
ids
.
size
()
/
task_pool_size_
;
bags
[
i
].
reserve
(
predsize
*
1.2
);
}
for
(
auto
x
:
ids
)
{
int
location
=
x
%
shard_num
%
task_pool_size_
;
bags
[
location
].
push_back
(
x
);
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
std
::
vector
<
uint64_t
>
node_array
[
task_pool_size_
];
// node id list
std
::
vector
<
paddle
::
framework
::
GpuPsNodeInfo
>
info_array
[
task_pool_size_
];
std
::
vector
<
uint64_t
>
edge_array
[
task_pool_size_
];
// edge id list
for
(
size_t
i
=
0
;
i
<
bags
.
size
();
i
++
)
{
if
(
bags
[
i
].
size
()
>
0
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
node_array
[
i
].
resize
(
bags
[
i
].
size
());
info_array
[
i
].
resize
(
bags
[
i
].
size
());
edge_array
[
i
].
reserve
(
bags
[
i
].
size
());
for
(
size_t
j
=
0
;
j
<
bags
[
i
].
size
();
j
++
)
{
auto
node_id
=
bags
[
i
][
j
];
node_array
[
i
][
j
]
=
node_id
;
Node
*
v
=
find_node
(
0
,
idx
,
node_id
);
if
(
v
!=
nullptr
)
{
info_array
[
i
][
j
].
neighbor_offset
=
edge_array
[
i
].
size
();
info_array
[
i
][
j
].
neighbor_size
=
v
->
get_neighbor_size
();
for
(
size_t
k
=
0
;
k
<
v
->
get_neighbor_size
();
k
++
)
{
edge_array
[
i
].
push_back
(
v
->
get_neighbor_id
(
k
));
}
}
else
{
info_array
[
i
][
j
].
neighbor_offset
=
0
;
info_array
[
i
][
j
].
neighbor_size
=
0
;
}
}
return
0
;
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
int64_t
tot_len
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
tot_len
+=
edge_array
[
i
].
size
();
}
paddle
::
framework
::
GpuPsCommGraph
res
;
res
.
init_on_cpu
(
tot_len
,
ids
.
size
());
int64_t
offset
=
0
,
ind
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
node_array
[
i
].
size
();
j
++
)
{
res
.
node_list
[
ind
]
=
node_array
[
i
][
j
];
res
.
node_info_list
[
ind
]
=
info_array
[
i
][
j
];
res
.
node_info_list
[
ind
++
].
neighbor_offset
+=
offset
;
}
for
(
size_t
j
=
0
;
j
<
edge_array
[
i
].
size
();
j
++
)
{
res
.
neighbor_list
[
offset
+
j
]
=
edge_array
[
i
][
j
];
}
offset
+=
edge_array
[
i
].
size
();
}
return
res
;
}
int32_t
GraphTable
::
add_node_to_ssd
(
int
type_id
,
int
idx
,
uint64_t
src_id
,
char
*
data
,
int
len
)
{
if
(
_db
!=
NULL
)
{
char
ch
[
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
)];
memcpy
(
ch
,
&
type_id
,
sizeof
(
int
));
memcpy
(
ch
+
sizeof
(
int
),
&
idx
,
sizeof
(
int
));
memcpy
(
ch
+
sizeof
(
int
)
*
2
,
&
src_id
,
sizeof
(
uint64_t
));
std
::
string
str
;
if
(
_db
->
get
(
src_id
%
shard_num
%
task_pool_size_
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
uint64_t
*
stored_data
=
((
uint64_t
*
)
str
.
c_str
());
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
char
*
new_data
=
new
char
[
n
*
sizeof
(
uint64_t
)
+
len
];
memcpy
(
new_data
,
stored_data
,
n
*
sizeof
(
uint64_t
));
memcpy
(
new_data
+
n
*
sizeof
(
uint64_t
),
data
,
len
);
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
(
char
*
)
new_data
,
n
*
sizeof
(
uint64_t
)
+
len
);
delete
[]
new_data
;
}
else
{
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
(
char
*
)
data
,
len
);
}
}
return
0
;
}
char
*
GraphTable
::
random_sample_neighbor_from_ssd
(
int
idx
,
uint64_t
id
,
int
sample_size
,
const
std
::
shared_ptr
<
std
::
mt19937_64
>
rng
,
int
&
actual_size
)
{
if
(
_db
==
NULL
)
{
actual_size
=
0
;
return
NULL
;
}
std
::
string
str
;
VLOG
(
2
)
<<
"sample ssd for key "
<<
id
;
char
ch
[
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
)];
memset
(
ch
,
0
,
sizeof
(
int
));
memcpy
(
ch
+
sizeof
(
int
),
&
idx
,
sizeof
(
int
));
memcpy
(
ch
+
sizeof
(
int
)
*
2
,
&
id
,
sizeof
(
uint64_t
));
if
(
_db
->
get
(
id
%
shard_num
%
task_pool_size_
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
uint64_t
*
data
=
((
uint64_t
*
)
str
.
c_str
());
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
std
::
unordered_map
<
int
,
int
>
m
;
// std::vector<uint64_t> res;
int
sm_size
=
std
::
min
(
n
,
sample_size
);
actual_size
=
sm_size
*
Node
::
id_size
;
char
*
buff
=
new
char
[
actual_size
];
for
(
int
i
=
0
;
i
<
sm_size
;
i
++
)
{
std
::
uniform_int_distribution
<
int
>
distrib
(
0
,
n
-
i
-
1
);
int
t
=
distrib
(
*
rng
);
// int t = rand() % (n-i);
int
pos
=
0
;
auto
iter
=
m
.
find
(
t
);
if
(
iter
!=
m
.
end
())
{
pos
=
iter
->
second
;
}
else
{
pos
=
t
;
}
auto
iter2
=
m
.
find
(
n
-
i
-
1
);
int
key2
=
iter2
==
m
.
end
()
?
n
-
i
-
1
:
iter2
->
second
;
m
[
t
]
=
key2
;
m
.
erase
(
n
-
i
-
1
);
memcpy
(
buff
+
i
*
Node
::
id_size
,
&
data
[
pos
],
Node
::
id_size
);
// res.push_back(data[pos]);
}
for
(
int
i
=
0
;
i
<
actual_size
;
i
+=
8
)
{
VLOG
(
2
)
<<
"sampled an neighbor "
<<
*
(
uint64_t
*
)
&
buff
[
i
];
}
return
buff
;
}
actual_size
=
0
;
return
NULL
;
}
int64_t
GraphTable
::
load_graph_to_memory_from_ssd
(
int
idx
,
std
::
vector
<
uint64_t
>
&
ids
)
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
bags
(
task_pool_size_
);
for
(
auto
x
:
ids
)
{
int
location
=
x
%
shard_num
%
task_pool_size_
;
bags
[
location
].
push_back
(
x
);
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
std
::
vector
<
int64_t
>
count
(
task_pool_size_
,
0
);
for
(
size_t
i
=
0
;
i
<
bags
.
size
();
i
++
)
{
if
(
bags
[
i
].
size
()
>
0
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
,
i
,
idx
,
this
]()
->
int
{
char
ch
[
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
)];
memset
(
ch
,
0
,
sizeof
(
int
));
memcpy
(
ch
+
sizeof
(
int
),
&
idx
,
sizeof
(
int
));
for
(
size_t
k
=
0
;
k
<
bags
[
i
].
size
();
k
++
)
{
auto
v
=
bags
[
i
][
k
];
memcpy
(
ch
+
sizeof
(
int
)
*
2
,
&
v
,
sizeof
(
uint64_t
));
std
::
string
str
;
if
(
_db
->
get
(
i
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
count
[
i
]
+=
(
int64_t
)
str
.
size
();
for
(
size_t
j
=
0
;
j
<
(
int
)
str
.
size
();
j
+=
sizeof
(
uint64_t
))
{
uint64_t
id
=
*
(
uint64_t
*
)(
str
.
c_str
()
+
j
);
add_comm_edge
(
idx
,
v
,
id
);
}
}
}
return
0
;
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
int64_t
tot
=
0
;
for
(
auto
x
:
count
)
tot
+=
x
;
return
tot
;
}
void
GraphTable
::
make_partitions
(
int
idx
,
int64_t
byte_size
,
int
device_len
)
{
VLOG
(
2
)
<<
"start to make graph partitions , byte_size = "
<<
byte_size
<<
" total memory cost = "
<<
total_memory_cost
;
if
(
total_memory_cost
==
0
)
{
VLOG
(
0
)
<<
"no edges are detected,make partitions exits"
;
return
;
}
auto
&
weight_map
=
node_weight
[
0
][
idx
];
const
double
a
=
2.0
,
y
=
1.25
,
weight_param
=
1.0
;
int64_t
gb_size_by_discount
=
byte_size
*
0.8
*
device_len
;
if
(
gb_size_by_discount
<=
0
)
gb_size_by_discount
=
1
;
int
part_len
=
total_memory_cost
/
gb_size_by_discount
;
if
(
part_len
==
0
)
part_len
=
1
;
VLOG
(
2
)
<<
"part_len = "
<<
part_len
<<
" byte size = "
<<
gb_size_by_discount
;
partitions
[
idx
].
clear
();
partitions
[
idx
].
resize
(
part_len
);
std
::
vector
<
double
>
weight_cost
(
part_len
,
0
);
std
::
vector
<
int64_t
>
memory_remaining
(
part_len
,
gb_size_by_discount
);
std
::
vector
<
double
>
score
(
part_len
,
0
);
std
::
unordered_map
<
uint64_t
,
int
>
id_map
;
std
::
vector
<
rocksdb
::
Iterator
*>
iters
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
iters
.
push_back
(
_db
->
get_iterator
(
i
));
iters
[
i
]
->
SeekToFirst
();
}
int
next
=
0
;
while
(
iters
.
size
())
{
if
(
next
>=
(
int
)
iters
.
size
())
{
next
=
0
;
}
if
(
!
iters
[
next
]
->
Valid
())
{
iters
.
erase
(
iters
.
begin
()
+
next
);
continue
;
}
std
::
string
key
=
iters
[
next
]
->
key
().
ToString
();
int
type_idx
=
*
(
int
*
)
key
.
c_str
();
int
temp_idx
=
*
(
int
*
)(
key
.
c_str
()
+
sizeof
(
int
));
if
(
type_idx
!=
0
||
temp_idx
!=
idx
)
{
iters
[
next
]
->
Next
();
next
++
;
continue
;
}
std
::
string
value
=
iters
[
next
]
->
value
().
ToString
();
std
::
uint64_t
i_key
=
*
(
uint64_t
*
)(
key
.
c_str
()
+
sizeof
(
int
)
*
2
);
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
if
(
memory_remaining
[
i
]
<
(
int64_t
)
value
.
size
())
{
score
[
i
]
=
-
100000.0
;
}
else
{
score
[
i
]
=
0
;
}
}
for
(
size_t
j
=
0
;
j
<
(
int
)
value
.
size
();
j
+=
sizeof
(
uint64_t
))
{
uint64_t
v
=
*
((
uint64_t
*
)(
value
.
c_str
()
+
j
));
int
index
=
-
1
;
if
(
id_map
.
find
(
v
)
!=
id_map
.
end
())
{
index
=
id_map
[
v
];
score
[
index
]
++
;
}
}
double
base
,
weight_base
=
0
;
double
w
=
0
;
bool
has_weight
=
false
;
if
(
weight_map
.
find
(
i_key
)
!=
weight_map
.
end
())
{
w
=
weight_map
[
i_key
];
has_weight
=
true
;
}
int
index
=
0
;
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
base
=
gb_size_by_discount
-
memory_remaining
[
i
]
+
value
.
size
();
if
(
has_weight
)
weight_base
=
weight_cost
[
i
]
+
w
*
weight_param
;
else
{
weight_base
=
0
;
}
score
[
i
]
-=
a
*
y
*
std
::
pow
(
1.0
*
base
,
y
-
1
)
+
weight_base
;
if
(
score
[
i
]
>
score
[
index
])
index
=
i
;
VLOG
(
2
)
<<
"score"
<<
i
<<
" = "
<<
score
[
i
]
<<
" memory left "
<<
memory_remaining
[
i
];
}
id_map
[
i_key
]
=
index
;
partitions
[
idx
][
index
].
push_back
(
i_key
);
memory_remaining
[
index
]
-=
(
int64_t
)
value
.
size
();
if
(
has_weight
)
weight_cost
[
index
]
+=
w
;
iters
[
next
]
->
Next
();
next
++
;
}
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
if
(
partitions
[
idx
][
i
].
size
()
==
0
)
{
partitions
[
idx
].
erase
(
partitions
[
idx
].
begin
()
+
i
);
i
--
;
part_len
--
;
continue
;
}
VLOG
(
2
)
<<
" partition "
<<
i
<<
" size = "
<<
partitions
[
idx
][
i
].
size
();
for
(
auto
x
:
partitions
[
idx
][
i
])
{
VLOG
(
2
)
<<
"find a id "
<<
x
;
}
}
next_partition
=
0
;
}
void
GraphTable
::
export_partition_files
(
int
idx
,
std
::
string
file_path
)
{
int
part_len
=
partitions
[
idx
].
size
();
if
(
part_len
==
0
)
return
;
if
(
file_path
==
""
)
file_path
=
"."
;
if
(
file_path
[(
int
)
file_path
.
size
()
-
1
]
!=
'/'
)
{
file_path
+=
"/"
;
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
[
&
,
i
,
idx
,
this
]()
->
int
{
std
::
string
output_path
=
file_path
+
"partition_"
+
std
::
to_string
(
i
);
std
::
ofstream
ofs
(
output_path
);
if
(
ofs
.
fail
())
{
VLOG
(
0
)
<<
"creating "
<<
output_path
<<
" failed"
;
return
0
;
}
for
(
auto
x
:
partitions
[
idx
][
i
])
{
auto
str
=
std
::
to_string
(
x
);
ofs
.
write
(
str
.
c_str
(),
str
.
size
());
ofs
.
write
(
"
\n
"
,
1
);
}
ofs
.
close
();
return
0
;
}));
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
}
void GraphTable::clear_graph(int idx) {
  for (auto p : edge_shards[idx]) {
    delete p;
  }
  edge_shards[idx].clear();

  for (size_t i = 0; i < shard_num_per_server; i++) {
    edge_shards[idx].push_back(new GraphShard());
  }
}

int32_t GraphTable::load_next_partition(int idx) {
  if (next_partition >= (int)partitions[idx].size()) {
    VLOG(0) << "partition iteration is done";
    return -1;
  }
  clear_graph(idx);
  load_graph_to_memory_from_ssd(idx, partitions[idx][next_partition]);
  next_partition++;
  return 0;
}
int32_t
GraphTable
::
load_edges_to_ssd
(
const
std
::
string
&
path
,
bool
reverse_edge
,
const
std
::
string
&
edge_type
)
{
int
idx
=
0
;
if
(
edge_type
==
""
)
{
VLOG
(
0
)
<<
"edge_type not specified, loading edges to "
<<
id_to_edge
[
0
]
<<
" part"
;
}
else
{
if
(
edge_to_id
.
find
(
edge_type
)
==
edge_to_id
.
end
())
{
VLOG
(
0
)
<<
"edge_type "
<<
edge_type
<<
" is not defined, nothing will be loaded"
;
return
0
;
}
idx
=
edge_to_id
[
edge_type
];
}
total_memory_cost
=
0
;
auto
paths
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
";"
);
int64_t
count
=
0
;
std
::
string
sample_type
=
"random"
;
for
(
auto
path
:
paths
)
{
std
::
ifstream
file
(
path
);
std
::
string
line
;
while
(
std
::
getline
(
file
,
line
))
{
VLOG
(
0
)
<<
"get a line from file "
<<
line
;
auto
values
=
paddle
::
string
::
split_string
<
std
::
string
>
(
line
,
"
\t
"
);
count
++
;
if
(
values
.
size
()
<
2
)
continue
;
auto
src_id
=
std
::
stoll
(
values
[
0
]);
auto
dist_ids
=
paddle
::
string
::
split_string
<
std
::
string
>
(
values
[
1
],
";"
);
std
::
vector
<
uint64_t
>
dist_data
;
for
(
auto
x
:
dist_ids
)
{
dist_data
.
push_back
(
std
::
stoll
(
x
));
total_memory_cost
+=
sizeof
(
uint64_t
);
}
add_node_to_ssd
(
0
,
idx
,
src_id
,
(
char
*
)
dist_data
.
data
(),
(
int
)(
dist_data
.
size
()
*
sizeof
(
uint64_t
)));
}
}
VLOG
(
0
)
<<
"total memory cost = "
<<
total_memory_cost
<<
" bytes"
;
return
0
;
}
int32_t
GraphTable
::
dump_edges_to_ssd
(
int
idx
)
{
VLOG
(
2
)
<<
"calling dump edges to ssd"
;
std
::
vector
<
std
::
future
<
int64_t
>>
tasks
;
auto
&
shards
=
edge_shards
[
idx
];
for
(
size_t
i
=
0
;
i
<
shards
.
size
();
++
i
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
[
&
,
i
,
this
]()
->
int64_t
{
int64_t
cost
=
0
;
std
::
vector
<
Node
*>
&
v
=
shards
[
i
]
->
get_bucket
();
for
(
size_t
j
=
0
;
j
<
v
.
size
();
j
++
)
{
std
::
vector
<
uint64_t
>
s
;
for
(
size_t
k
=
0
;
k
<
(
int
)
v
[
j
]
->
get_neighbor_size
();
k
++
)
{
s
.
push_back
(
v
[
j
]
->
get_neighbor_id
(
k
));
}
cost
+=
v
[
j
]
->
get_neighbor_size
()
*
sizeof
(
uint64_t
);
add_node_to_ssd
(
0
,
idx
,
v
[
j
]
->
get_id
(),
(
char
*
)
s
.
data
(),
s
.
size
()
*
sizeof
(
uint64_t
));
}
return
cost
;
}));
}
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
total_memory_cost
+=
tasks
[
i
].
get
();
return
0
;
}
int32_t
GraphTable
::
make_complementary_graph
(
int
idx
,
int64_t
byte_size
)
{
VLOG
(
0
)
<<
"make_complementary_graph"
;
const
int64_t
fixed_size
=
byte_size
/
8
;
// std::vector<int64_t> edge_array[task_pool_size_];
std
::
vector
<
std
::
unordered_map
<
uint64_t
,
int
>>
count
(
task_pool_size_
);
std
::
vector
<
std
::
future
<
int
>>
tasks
;
auto
&
shards
=
edge_shards
[
idx
];
for
(
size_t
i
=
0
;
i
<
shards
.
size
();
++
i
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
std
::
vector
<
Node
*>
&
v
=
shards
[
i
]
->
get_bucket
();
size_t
ind
=
i
%
this
->
task_pool_size_
;
for
(
size_t
j
=
0
;
j
<
v
.
size
();
j
++
)
{
// size_t location = v[j]->get_id();
for
(
size_t
k
=
0
;
k
<
v
[
j
]
->
get_neighbor_size
();
k
++
)
{
count
[
ind
][
v
[
j
]
->
get_neighbor_id
(
k
)]
++
;
}
}
return
0
;
}));
}
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
std
::
unordered_map
<
uint64_t
,
int
>
final_count
;
std
::
map
<
int
,
std
::
vector
<
uint64_t
>>
count_to_id
;
std
::
vector
<
uint64_t
>
buffer
;
clear_graph
(
idx
);
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
auto
&
p
:
count
[
i
])
{
final_count
[
p
.
first
]
=
final_count
[
p
.
first
]
+
p
.
second
;
}
count
[
i
].
clear
();
}
for
(
auto
&
p
:
final_count
)
{
count_to_id
[
p
.
second
].
push_back
(
p
.
first
);
VLOG
(
2
)
<<
p
.
first
<<
" appear "
<<
p
.
second
<<
" times"
;
}
auto
iter
=
count_to_id
.
rbegin
();
while
(
iter
!=
count_to_id
.
rend
()
&&
byte_size
>
0
)
{
for
(
auto
x
:
iter
->
second
)
{
buffer
.
push_back
(
x
);
if
(
buffer
.
size
()
>=
fixed_size
)
{
int64_t
res
=
load_graph_to_memory_from_ssd
(
idx
,
buffer
);
buffer
.
clear
();
byte_size
-=
res
;
}
if
(
byte_size
<=
0
)
break
;
}
iter
++
;
}
if
(
byte_size
>
0
&&
buffer
.
size
()
>
0
)
{
int64_t
res
=
load_graph_to_memory_from_ssd
(
idx
,
buffer
);
byte_size
-=
res
;
}
std
::
string
sample_type
=
"random"
;
for
(
auto
&
shard
:
edge_shards
[
idx
])
{
auto
bucket
=
shard
->
get_bucket
();
for
(
size_t
i
=
0
;
i
<
bucket
.
size
();
i
++
)
{
bucket
[
i
]
->
build_sampler
(
sample_type
);
}
}
return
0
;
}
#endif
/*
int CompleteGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
std::cout << "in graph sampling" << std::endl;
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node;
std::vector<Node *> &v =
this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
size_t location = v[j]->get_id() % this->gpu_num;
node.node_id = v[j]->get_id();
node.neighbor_size = v[j]->get_neighbor_size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (int k = 0; k < node.neighbor_size; k++)
sample_neighbors_ex[ind][location].push_back(
v[j]->get_neighbor_id(k));
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
tasks.clear();
for (int i = 0; i < gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
int total_offset = 0;
size_t ind = i % this->graph_table->task_pool_size_;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) {
sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]);
sample_nodes[ind].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][ind].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[ind].push_back(
sample_neighbors_ex[j][ind][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
return 0;
}
void CompleteGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
}
int BasicBfsGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
while (rounds > 0 && status == GraphSamplerStatus::running) {
for (size_t i = 0; i < sample_neighbors_map.size(); i++) {
sample_neighbors_map[i].clear();
}
sample_neighbors_map.clear();
std::vector<int> nodes_left(graph_table->shards.size(),
node_num_for_each_shard);
std::promise<int> prom;
std::future<int> fut = prom.get_future();
sample_neighbors_map.resize(graph_table->task_pool_size_);
int task_size = 0;
std::vector<std::future<int>> tasks;
int init_size = 0;
//__sync_fetch_and_add
std::function<int(int, int64_t)> bfs = [&, this](int i, int id) -> int {
if (this->status == GraphSamplerStatus::terminating) {
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
}
size_t ind = i % this->graph_table->task_pool_size_;
if (nodes_left[i] > 0) {
auto iter = sample_neighbors_map[ind].find(id);
if (iter == sample_neighbors_map[ind].end()) {
Node *node = graph_table->shards[i]->find_node(id);
if (node != NULL) {
nodes_left[i]--;
sample_neighbors_map[ind][id] = std::vector<int64_t>();
iter = sample_neighbors_map[ind].find(id);
size_t edge_fetch_size =
std::min((size_t) this->edge_num_for_each_node,
node->get_neighbor_size());
for (size_t k = 0; k < edge_fetch_size; k++) {
int64_t neighbor_id = node->get_neighbor_id(k);
int node_location = neighbor_id % this->graph_table->shard_num %
this->graph_table->task_pool_size_;
__sync_add_and_fetch(&task_size, 1);
graph_table->_shards_task_pool[node_location]->enqueue(
bfs, neighbor_id % this->graph_table->shard_num, neighbor_id);
iter->second.push_back(neighbor_id);
}
}
}
}
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
};
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
std::vector<Node *> &v = graph_table->shards[i]->get_bucket();
if (v.size() > 0) {
int search_size = std::min(init_search_size, (int)v.size());
for (int k = 0; k < search_size; k++) {
init_size++;
__sync_add_and_fetch(&task_size, 1);
int64_t id = v[k]->get_id();
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue(bfs, i, id);
}
} // if
}
if (init_size == 0) {
prom.set_value(0);
}
fut.get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
VLOG(0) << "BasicBfsGraphSampler finishes the graph searching task";
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
tasks.clear();
for (size_t i = 0; i < (size_t)graph_table->task_pool_size_; ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i]->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
paddle::framework::GpuPsGraphNode node;
auto iter = sample_neighbors_map[i].begin();
size_t ind = i;
for (; iter != sample_neighbors_map[i].end(); iter++) {
size_t location = iter->first % this->gpu_num;
node.node_id = iter->first;
node.neighbor_size = iter->second.size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (auto k : iter->second)
sample_neighbors_ex[ind][location].push_back(k);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) {
tasks[i].get();
sample_neighbors_map[i].clear();
}
tasks.clear();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (size_t i = 0; i < (size_t)gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
int total_offset = 0;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][i].size(); k++) {
sample_nodes[i].push_back(sample_nodes_ex[j][i][k]);
sample_nodes[i].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][i].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[i].push_back(sample_neighbors_ex[j][i][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
rounds--;
if (rounds > 0) {
for (int i = 0;
i < interval && this->status == GraphSamplerStatus::running; i++) {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
VLOG(0)<<"bfs returning";
}
return 0;
}
void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
init_search_size = args.size() > 0 ? std::stoi(args[0]) : 10;
node_num_for_each_shard = args.size() > 1 ? std::stoi(args[1]) : 10;
edge_num_for_each_node = args.size() > 2 ? std::stoi(args[2]) : 10;
rounds = args.size() > 3 ? std::stoi(args[3]) : 1;
interval = args.size() > 4 ? std::stoi(args[4]) : 60;
}
#endif
*/
std::vector<Node*> GraphShard::get_batch(int start, int end, int step) {
  if (start < 0) start = 0;
  std::vector<Node*> res;
  for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
    res.push_back(bucket[pos]);
  }
  return res;
}

size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id) {
  size_t src_shard_id = src_id % shard_num;

  if (src_shard_id >= shard_end || src_shard_id < shard_start) {
    return -1;
  }
  size_t index = src_shard_id - shard_start;
  edge_shards[idx][index]->add_graph_node(src_id)->build_edges(false);
  edge_shards[idx][index]->add_neighbor(src_id, dst_id, 1.0);
  return 0;
}
int32_t
GraphTable
::
add_graph_node
(
int
idx
,
std
::
vector
<
uint64_t
>
&
id_list
,
std
::
vector
<
bool
>
&
is_weight_list
)
{
auto
&
shards
=
edge_shards
[
idx
];
size_t
node_size
=
id_list
.
size
();
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
bool
>>>
batch
(
task_pool_size_
);
for
(
size_t
i
=
0
;
i
<
node_size
;
i
++
)
{
size_t
shard_id
=
id_list
[
i
]
%
shard_num
;
if
(
shard_id
>=
shard_end
||
shard_id
<
shard_start
)
{
continue
;
}
batch
[
get_thread_pool_index
(
id_list
[
i
])].
push_back
(
{
id_list
[
i
],
i
<
is_weight_list
.
size
()
?
is_weight_list
[
i
]
:
false
});
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
for
(
size_t
i
=
0
;
i
<
batch
.
size
();
++
i
)
{
if
(
!
batch
[
i
].
size
())
continue
;
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
shards
,
&
batch
,
i
,
this
]()
->
int
{
for
(
auto
&
p
:
batch
[
i
])
{
size_t
index
=
p
.
first
%
this
->
shard_num
-
this
->
shard_start
;
shards
[
index
]
->
add_graph_node
(
p
.
first
)
->
build_edges
(
p
.
second
);
}
return
0
;
}));
}
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
return
0
;
}
int32_t
GraphTable
::
remove_graph_node
(
int
idx
,
std
::
vector
<
uint64_t
>
&
id_list
)
{
size_t
node_size
=
id_list
.
size
();
std
::
vector
<
std
::
vector
<
uint64_t
>>
batch
(
task_pool_size_
);
for
(
size_t
i
=
0
;
i
<
node_size
;
i
++
)
{
size_t
shard_id
=
id_list
[
i
]
%
shard_num
;
if
(
shard_id
>=
shard_end
||
shard_id
<
shard_start
)
continue
;
batch
[
get_thread_pool_index
(
id_list
[
i
])].
push_back
(
id_list
[
i
]);
}
auto
&
shards
=
edge_shards
[
idx
];
std
::
vector
<
std
::
future
<
int
>>
tasks
;
for
(
size_t
i
=
0
;
i
<
batch
.
size
();
++
i
)
{
if
(
!
batch
[
i
].
size
())
continue
;
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
shards
,
&
batch
,
i
,
this
]()
->
int
{
for
(
auto
&
p
:
batch
[
i
])
{
size_t
index
=
p
%
this
->
shard_num
-
this
->
shard_start
;
shards
[
index
]
->
delete_node
(
p
);
}
return
0
;
}));
}
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
return
0
;
}
void GraphShard::clear() {
  for (size_t i = 0; i < bucket.size(); i++) {
    delete bucket[i];
  }
  bucket.clear();
  node_location.clear();
}

GraphShard::~GraphShard() { clear(); }

void GraphShard::delete_node(uint64_t id) {
  auto iter = node_location.find(id);
  if (iter == node_location.end()) return;
  int pos = iter->second;
  delete bucket[pos];
  if (pos != (int)bucket.size() - 1) {
    bucket[pos] = bucket.back();
    node_location[bucket.back()->get_id()] = pos;
  }
  node_location.erase(id);
  bucket.pop_back();
}

GraphNode* GraphShard::add_graph_node(uint64_t id) {
  if (node_location.find(id) == node_location.end()) {
    node_location[id] = bucket.size();
    bucket.push_back(new GraphNode(id));
  }
  return (GraphNode*)bucket[node_location[id]];
}

GraphNode* GraphShard::add_graph_node(Node* node) {
  auto id = node->get_id();
  if (node_location.find(id) == node_location.end()) {
    node_location[id] = bucket.size();
    bucket.push_back(node);
  }
  return (GraphNode*)bucket[node_location[id]];
}

FeatureNode* GraphShard::add_feature_node(uint64_t id, bool is_overlap) {
  if (node_location.find(id) == node_location.end()) {
    node_location[id] = bucket.size();
    bucket.push_back(new FeatureNode(id));
    return (FeatureNode*)bucket[node_location[id]];
  }
  if (is_overlap) {
    return (FeatureNode*)bucket[node_location[id]];
  }
  return NULL;
}

void GraphShard::add_neighbor(uint64_t id, uint64_t dst_id, float weight) {
  find_node(id)->add_edge(dst_id, weight);
}

Node* GraphShard::find_node(uint64_t id) {
  auto iter = node_location.find(id);
  return iter == node_location.end() ? nullptr : bucket[iter->second];
}
GraphTable::~GraphTable() {
  for (int i = 0; i < (int)edge_shards.size(); i++) {
    for (auto p : edge_shards[i]) {
      delete p;
    }
    edge_shards[i].clear();
  }

  for (int i = 0; i < (int)feature_shards.size(); i++) {
    for (auto p : feature_shards[i]) {
      delete p;
    }
    feature_shards[i].clear();
  }
}

int32_t GraphTable::Load(const std::string& path, const std::string& param) {
  bool load_edge = (param[0] == 'e');
  bool load_node = (param[0] == 'n');
  if (load_edge) {
    bool reverse_edge = (param[1] == '<');
    std::string edge_type = param.substr(2);
    return this->load_edges(path, reverse_edge, edge_type);
  }
  if (load_node) {
    std::string node_type = param.substr(1);
    return this->load_nodes(path, node_type);
  }
  return 0;
}

std::string GraphTable::get_inverse_etype(std::string& etype) {
  auto etype_split = paddle::string::split_string<std::string>(etype, "2");
  std::string res;
  if ((int)etype_split.size() == 3) {
    res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
  } else {
    res = etype_split[1] + "2" + etype_split[0];
  }
  return res;
}
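A hedged standalone sketch of the edge-type inversion performed by get_inverse_etype above, assuming edge types are named "src2dst" or "src2rel2dst"; the type names used below are hypothetical.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Standalone sketch (not Paddle code): reverses an edge-type name split on
// "2", mirroring the logic of GraphTable::get_inverse_etype.
std::string inverse_etype(const std::string& etype) {
  std::vector<std::string> parts;
  std::stringstream ss(etype);
  std::string piece;
  while (std::getline(ss, piece, '2')) parts.push_back(piece);
  if (parts.size() == 3) return parts[2] + "2" + parts[1] + "2" + parts[0];
  return parts[1] + "2" + parts[0];
}

int main() {
  std::cout << inverse_etype("user2item") << "\n";        // item2user
  std::cout << inverse_etype("user2click2item") << "\n";  // item2click2user
}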
int32_t
GraphTable
::
load_node_and_edge_file
(
std
::
string
etype
,
std
::
string
ntype
,
std
::
string
epath
,
std
::
string
npath
,
int
part_num
,
bool
reverse
)
{
auto
etypes
=
paddle
::
string
::
split_string
<
std
::
string
>
(
etype
,
","
);
auto
ntypes
=
paddle
::
string
::
split_string
<
std
::
string
>
(
ntype
,
","
);
VLOG
(
0
)
<<
"etypes size: "
<<
etypes
.
size
();
VLOG
(
0
)
<<
"whether reverse: "
<<
reverse
;
std
::
string
delim
=
";"
;
size_t
total_len
=
etypes
.
size
()
+
1
;
// 1 is for node
std
::
vector
<
std
::
future
<
int
>>
tasks
;
for
(
size_t
i
=
0
;
i
<
total_len
;
i
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
if
(
i
<
etypes
.
size
())
{
std
::
string
etype_path
=
epath
+
"/"
+
etypes
[
i
];
auto
etype_path_list
=
paddle
::
framework
::
localfs_list
(
etype_path
);
std
::
string
etype_path_str
;
if
(
part_num
>
0
&&
part_num
<
(
int
)
etype_path_list
.
size
())
{
std
::
vector
<
std
::
string
>
sub_etype_path_list
(
etype_path_list
.
begin
(),
etype_path_list
.
begin
()
+
part_num
);
etype_path_str
=
paddle
::
string
::
join_strings
(
sub_etype_path_list
,
delim
);
}
else
{
etype_path_str
=
paddle
::
string
::
join_strings
(
etype_path_list
,
delim
);
}
this
->
load_edges
(
etype_path_str
,
false
,
etypes
[
i
]);
if
(
reverse
)
{
std
::
string
r_etype
=
get_inverse_etype
(
etypes
[
i
]);
this
->
load_edges
(
etype_path_str
,
true
,
r_etype
);
}
}
else
{
auto
npath_list
=
paddle
::
framework
::
localfs_list
(
npath
);
std
::
string
npath_str
;
if
(
part_num
>
0
&&
part_num
<
(
int
)
npath_list
.
size
())
{
std
::
vector
<
std
::
string
>
sub_npath_list
(
npath_list
.
begin
(),
npath_list
.
begin
()
+
part_num
);
npath_str
=
paddle
::
string
::
join_strings
(
sub_npath_list
,
delim
);
}
else
{
npath_str
=
paddle
::
string
::
join_strings
(
npath_list
,
delim
);
}
if
(
ntypes
.
size
()
==
0
)
{
VLOG
(
0
)
<<
"node_type not specified, nothing will be loaded "
;
return
0
;
}
if
(
FLAGS_graph_load_in_parallel
)
{
this
->
load_nodes
(
npath_str
,
""
);
}
else
{
for
(
size_t
j
=
0
;
j
<
ntypes
.
size
();
j
++
)
{
this
->
load_nodes
(
npath_str
,
ntypes
[
j
]);
}
}
}
return
0
;
}));
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
return
0
;
}
int32_t
GraphTable
::
get_nodes_ids_by_ranges
(
int
type_id
,
int
idx
,
std
::
vector
<
std
::
pair
<
int
,
int
>>
ranges
,
std
::
vector
<
uint64_t
>
&
res
)
{
std
::
mutex
mutex
;
int
start
=
0
,
end
,
index
=
0
,
total_size
=
0
;
res
.
clear
();
auto
&
shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
for
(
size_t
i
=
0
;
i
<
shards
.
size
()
&&
index
<
(
int
)
ranges
.
size
();
i
++
)
{
end
=
total_size
+
shards
[
i
]
->
get_size
();
start
=
total_size
;
while
(
start
<
end
&&
index
<
(
int
)
ranges
.
size
())
{
if
(
ranges
[
index
].
second
<=
start
)
index
++
;
else
if
(
ranges
[
index
].
first
>=
end
)
{
break
;
}
else
{
int
first
=
std
::
max
(
ranges
[
index
].
first
,
start
);
int
second
=
std
::
min
(
ranges
[
index
].
second
,
end
);
start
=
second
;
first
-=
total_size
;
second
-=
total_size
;
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
[
&
shards
,
this
,
first
,
second
,
i
,
&
res
,
&
mutex
]()
->
size_t
{
std
::
vector
<
uint64_t
>
keys
;
shards
[
i
]
->
get_ids_by_range
(
first
,
second
,
&
keys
);
size_t
num
=
keys
.
size
();
mutex
.
lock
();
res
.
reserve
(
res
.
size
()
+
num
);
for
(
auto
&
id
:
keys
)
{
res
.
push_back
(
id
);
std
::
swap
(
res
[
rand
()
%
res
.
size
()],
res
[(
int
)
res
.
size
()
-
1
]);
}
mutex
.
unlock
();
return
num
;
}));
}
}
total_size
+=
shards
[
i
]
->
get_size
();
}
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
{
tasks
[
i
].
get
();
}
return
0
;
}
std
::
pair
<
uint64_t
,
uint64_t
>
GraphTable
::
parse_node_file
(
const
std
::
string
&
path
,
const
std
::
string
&
node_type
,
int
idx
)
{
std
::
ifstream
file
(
path
);
std
::
string
line
;
uint64_t
local_count
=
0
;
uint64_t
local_valid_count
=
0
;
int
num
=
0
;
std
::
vector
<
paddle
::
string
::
str_ptr
>
vals
;
size_t
n
=
node_type
.
length
();
while
(
std
::
getline
(
file
,
line
))
{
if
(
strncmp
(
line
.
c_str
(),
node_type
.
c_str
(),
n
)
!=
0
)
{
continue
;
}
vals
.
clear
();
num
=
paddle
::
string
::
split_string_ptr
(
line
.
c_str
()
+
n
+
1
,
line
.
length
()
-
n
-
1
,
'\t'
,
&
vals
);
if
(
num
==
0
)
{
continue
;
}
uint64_t
id
=
std
::
strtoul
(
vals
[
0
].
ptr
,
NULL
,
10
);
size_t
shard_id
=
id
%
shard_num
;
if
(
shard_id
>=
shard_end
||
shard_id
<
shard_start
)
{
VLOG
(
4
)
<<
"will not load "
<<
id
<<
" from "
<<
path
<<
", please check id distribution"
;
continue
;
}
local_count
++
;
size_t
index
=
shard_id
-
shard_start
;
auto
node
=
feature_shards
[
idx
][
index
]
->
add_feature_node
(
id
,
false
);
if
(
node
!=
NULL
)
{
node
->
set_feature_size
(
feat_name
[
idx
].
size
());
for
(
int
i
=
1
;
i
<
num
;
++
i
)
{
auto
&
v
=
vals
[
i
];
parse_feature
(
idx
,
v
.
ptr
,
v
.
len
,
node
);
}
}
local_valid_count
++
;
}
VLOG
(
2
)
<<
"node_type["
<<
node_type
<<
"] loads "
<<
local_count
<<
" nodes from filepath->"
<<
path
;
return
{
local_count
,
local_valid_count
};
}
std
::
pair
<
uint64_t
,
uint64_t
>
GraphTable
::
parse_node_file
(
const
std
::
string
&
path
)
{
std
::
ifstream
file
(
path
);
std
::
string
line
;
uint64_t
local_count
=
0
;
uint64_t
local_valid_count
=
0
;
int
idx
=
0
;
auto
path_split
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
"/"
);
auto
path_name
=
path_split
[
path_split
.
size
()
-
1
];
int
num
=
0
;
std
::
vector
<
paddle
::
string
::
str_ptr
>
vals
;
while
(
std
::
getline
(
file
,
line
))
{
vals
.
clear
();
num
=
paddle
::
string
::
split_string_ptr
(
line
.
c_str
(),
line
.
length
(),
'\t'
,
&
vals
);
if
(
vals
.
empty
())
{
continue
;
}
std
::
string
parse_node_type
=
vals
[
0
].
to_string
();
auto
it
=
feature_to_id
.
find
(
parse_node_type
);
if
(
it
==
feature_to_id
.
end
())
{
VLOG
(
0
)
<<
parse_node_type
<<
"type error, please check"
;
continue
;
}
idx
=
it
->
second
;
uint64_t
id
=
std
::
strtoul
(
vals
[
1
].
ptr
,
NULL
,
10
);
size_t
shard_id
=
id
%
shard_num
;
if
(
shard_id
>=
shard_end
||
shard_id
<
shard_start
)
{
VLOG
(
4
)
<<
"will not load "
<<
id
<<
" from "
<<
path
<<
", please check id distribution"
;
continue
;
}
local_count
++
;
size_t
index
=
shard_id
-
shard_start
;
auto
node
=
feature_shards
[
idx
][
index
]
->
add_feature_node
(
id
,
false
);
if
(
node
!=
NULL
)
{
for
(
int
i
=
2
;
i
<
num
;
++
i
)
{
auto
&
v
=
vals
[
i
];
parse_feature
(
idx
,
v
.
ptr
,
v
.
len
,
node
);
}
}
local_valid_count
++
;
}
VLOG
(
2
)
<<
local_valid_count
<<
"/"
<<
local_count
<<
" nodes from filepath->"
<<
path
;
return
{
local_count
,
local_valid_count
};
}
// TODO opt load all node_types in once reading
int32_t
GraphTable
::
load_nodes
(
const
std
::
string
&
path
,
std
::
string
node_type
)
{
auto
paths
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
";"
);
uint64_t
count
=
0
;
uint64_t
valid_count
=
0
;
int
idx
=
0
;
if
(
FLAGS_graph_load_in_parallel
)
{
if
(
node_type
==
""
)
{
VLOG
(
0
)
<<
"Begin GraphTable::load_nodes(), will load all node_type once"
;
}
std
::
vector
<
std
::
future
<
std
::
pair
<
uint64_t
,
uint64_t
>>>
tasks
;
for
(
size_t
i
=
0
;
i
<
paths
.
size
();
i
++
)
{
tasks
.
push_back
(
load_node_edge_task_pool
->
enqueue
(
[
&
,
i
,
this
]()
->
std
::
pair
<
uint64_t
,
uint64_t
>
{
return
parse_node_file
(
paths
[
i
]);
}));
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
{
auto
res
=
tasks
[
i
].
get
();
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
}
}
else
{
VLOG
(
0
)
<<
"Begin GraphTable::load_nodes() node_type["
<<
node_type
<<
"]"
;
if
(
node_type
==
""
)
{
VLOG
(
0
)
<<
"node_type not specified, loading edges to "
<<
id_to_feature
[
0
]
<<
" part"
;
}
else
{
if
(
feature_to_id
.
find
(
node_type
)
==
feature_to_id
.
end
())
{
VLOG
(
0
)
<<
"node_type "
<<
node_type
<<
" is not defined, nothing will be loaded"
;
return
0
;
}
idx
=
feature_to_id
[
node_type
];
}
for
(
auto
path
:
paths
)
{
VLOG
(
2
)
<<
"Begin GraphTable::load_nodes(), path["
<<
path
<<
"]"
;
auto
res
=
parse_node_file
(
path
,
node_type
,
idx
);
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
}
}
VLOG
(
0
)
<<
valid_count
<<
"/"
<<
count
<<
" nodes in node_type[ "
<<
node_type
<<
"] are loaded successfully!"
;
return
0
;
}
int32_t
GraphTable
::
build_sampler
(
int
idx
,
std
::
string
sample_type
)
{
for
(
auto
&
shard
:
edge_shards
[
idx
])
{
auto
bucket
=
shard
->
get_bucket
();
for
(
size_t
i
=
0
;
i
<
bucket
.
size
();
i
++
)
{
bucket
[
i
]
->
build_sampler
(
sample_type
);
}
}
return
0
;
}
std
::
pair
<
uint64_t
,
uint64_t
>
GraphTable
::
parse_edge_file
(
const
std
::
string
&
path
,
int
idx
,
bool
reverse
)
{
std
::
string
sample_type
=
"random"
;
bool
is_weighted
=
false
;
std
::
ifstream
file
(
path
);
std
::
string
line
;
uint64_t
local_count
=
0
;
uint64_t
local_valid_count
=
0
;
uint64_t
part_num
=
0
;
if
(
FLAGS_graph_load_in_parallel
)
{
auto
path_split
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
"/"
);
auto
part_name_split
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path_split
[
path_split
.
size
()
-
1
],
"-"
);
part_num
=
std
::
stoull
(
part_name_split
[
part_name_split
.
size
()
-
1
]);
}
while
(
std
::
getline
(
file
,
line
))
{
size_t
start
=
line
.
find_first_of
(
'\t'
);
if
(
start
==
std
::
string
::
npos
)
continue
;
local_count
++
;
uint64_t
src_id
=
std
::
stoull
(
&
line
[
0
]);
uint64_t
dst_id
=
std
::
stoull
(
&
line
[
start
+
1
]);
if
(
reverse
)
{
std
::
swap
(
src_id
,
dst_id
);
}
size_t
src_shard_id
=
src_id
%
shard_num
;
if
(
FLAGS_graph_load_in_parallel
)
{
if
(
src_shard_id
!=
(
part_num
%
shard_num
))
{
continue
;
}
}
float
weight
=
1
;
size_t
last
=
line
.
find_last_of
(
'\t'
);
if
(
start
!=
last
)
{
weight
=
std
::
stof
(
&
line
[
last
+
1
]);
sample_type
=
"weighted"
;
is_weighted
=
true
;
}
if
(
src_shard_id
>=
shard_end
||
src_shard_id
<
shard_start
)
{
VLOG
(
4
)
<<
"will not load "
<<
src_id
<<
" from "
<<
path
<<
", please check id distribution"
;
continue
;
}
size_t
index
=
src_shard_id
-
shard_start
;
auto
node
=
edge_shards
[
idx
][
index
]
->
add_graph_node
(
src_id
);
if
(
node
!=
NULL
)
{
node
->
build_edges
(
is_weighted
);
node
->
add_edge
(
dst_id
,
weight
);
}
local_valid_count
++
;
}
VLOG
(
2
)
<<
local_count
<<
" edges are loaded from filepath->"
<<
path
;
return
{
local_count
,
local_valid_count
};
}
int32_t
GraphTable
::
load_edges
(
const
std
::
string
&
path
,
bool
reverse_edge
,
const
std
::
string
&
edge_type
)
{
#ifdef PADDLE_WITH_HETERPS
if
(
search_level
==
2
)
total_memory_cost
=
0
;
const
uint64_t
fixed_load_edges
=
1000000
;
#endif
int
idx
=
0
;
if
(
edge_type
==
""
)
{
VLOG
(
0
)
<<
"edge_type not specified, loading edges to "
<<
id_to_edge
[
0
]
<<
" part"
;
}
else
{
if
(
edge_to_id
.
find
(
edge_type
)
==
edge_to_id
.
end
())
{
VLOG
(
0
)
<<
"edge_type "
<<
edge_type
<<
" is not defined, nothing will be loaded"
;
return
0
;
}
idx
=
edge_to_id
[
edge_type
];
}
auto
paths
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
";"
);
uint64_t
count
=
0
;
uint64_t
valid_count
=
0
;
VLOG
(
0
)
<<
"Begin GraphTable::load_edges() edge_type["
<<
edge_type
<<
"]"
;
if
(
FLAGS_graph_load_in_parallel
)
{
std
::
vector
<
std
::
future
<
std
::
pair
<
uint64_t
,
uint64_t
>>>
tasks
;
for
(
int
i
=
0
;
i
<
paths
.
size
();
i
++
)
{
tasks
.
push_back
(
load_node_edge_task_pool
->
enqueue
(
[
&
,
i
,
idx
,
this
]()
->
std
::
pair
<
uint64_t
,
uint64_t
>
{
return
parse_edge_file
(
paths
[
i
],
idx
,
reverse_edge
);
}));
}
for
(
int
j
=
0
;
j
<
(
int
)
tasks
.
size
();
j
++
)
{
auto
res
=
tasks
[
j
].
get
();
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
}
}
else
{
for
(
auto
path
:
paths
)
{
auto
res
=
parse_edge_file
(
path
,
idx
,
reverse_edge
);
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
}
}
VLOG
(
0
)
<<
valid_count
<<
"/"
<<
count
<<
" edge_type["
<<
edge_type
<<
"] edges are loaded successfully"
;
#ifdef PADDLE_WITH_HETERPS
if
(
search_level
==
2
)
{
if
(
count
>
0
)
{
dump_edges_to_ssd
(
idx
);
VLOG
(
0
)
<<
"dumping edges to ssd, edge count is reset to 0"
;
clear_graph
(
idx
);
count
=
0
;
}
return
0
;
}
#endif
if
(
!
build_sampler_on_cpu
)
{
// To reduce memory overhead, CPU samplers won't be created in gpugraph.
// In order not to affect the sampler function of other scenario,
// this optimization is only performed in load_edges function.
VLOG
(
0
)
<<
"run in gpugraph mode!"
;
}
else
{
std
::
string
sample_type
=
"random"
;
VLOG
(
0
)
<<
"build sampler ... "
;
for
(
auto
&
shard
:
edge_shards
[
idx
])
{
auto
bucket
=
shard
->
get_bucket
();
for
(
size_t
i
=
0
;
i
<
bucket
.
size
();
i
++
)
{
bucket
[
i
]
->
build_sampler
(
sample_type
);
}
}
}
return
0
;
}
Node* GraphTable::find_node(int type_id, uint64_t id) {
  size_t shard_id = id % shard_num;
  if (shard_id >= shard_end || shard_id < shard_start) {
    return nullptr;
  }
  Node* node = nullptr;
  size_t index = shard_id - shard_start;
  auto& search_shards = type_id == 0 ? edge_shards : feature_shards;
  for (auto& search_shard : search_shards) {
    PADDLE_ENFORCE_NOT_NULL(
        search_shard[index],
        paddle::platform::errors::InvalidArgument(
            "search_shard[%d] should not be null.", index));
    node = search_shard[index]->find_node(id);
    if (node != nullptr) {
      break;
    }
  }
  return node;
}

Node* GraphTable::find_node(int type_id, int idx, uint64_t id) {
  size_t shard_id = id % shard_num;
  if (shard_id >= shard_end || shard_id < shard_start) {
    return nullptr;
  }
  size_t index = shard_id - shard_start;
  auto& search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  PADDLE_ENFORCE_NOT_NULL(
      search_shards[index],
      paddle::platform::errors::InvalidArgument(
          "search_shard[%d] should not be null.", index));
  Node* node = search_shards[index]->find_node(id);
  return node;
}

uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
  return node_id % shard_num % shard_num_per_server % task_pool_size_;
}

uint32_t GraphTable::get_thread_pool_index_by_shard_index(
    uint64_t shard_index) {
  return shard_index % shard_num_per_server % task_pool_size_;
}

int32_t GraphTable::clear_nodes(int type_id, int idx) {
  auto& search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  for (size_t i = 0; i < search_shards.size(); i++) {
    search_shards[i]->clear();
  }
  return 0;
}
int32_t GraphTable::random_sample_nodes(int type_id,
                                        int idx,
                                        int sample_size,
                                        std::unique_ptr<char[]> &buffer,
                                        int &actual_size) {
  int total_size = 0;
  auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  for (int i = 0; i < (int)shards.size(); i++) {
    total_size += shards[i]->get_size();
  }
  if (sample_size > total_size) sample_size = total_size;
  int range_num = random_sample_nodes_ranges;
  if (range_num > sample_size) range_num = sample_size;
  if (sample_size == 0 || range_num == 0) return 0;
  std::vector<int> ranges_len, ranges_pos;
  int remain = sample_size, last_pos = -1, num;
  std::set<int> separator_set;
  for (int i = 0; i < range_num - 1; i++) {
    while (separator_set.find(num = rand() % (sample_size - 1)) !=
           separator_set.end())
      ;
    separator_set.insert(num);
  }
  for (auto p : separator_set) {
    ranges_len.push_back(p - last_pos);
    last_pos = p;
  }
  ranges_len.push_back(sample_size - 1 - last_pos);
  remain = total_size - sample_size + range_num;
  separator_set.clear();
  for (int i = 0; i < range_num; i++) {
    while (separator_set.find(num = rand() % remain) != separator_set.end())
      ;
    separator_set.insert(num);
  }
  int used = 0, index = 0;
  last_pos = -1;
  for (auto p : separator_set) {
    used += p - last_pos - 1;
    last_pos = p;
    ranges_pos.push_back(used);
    used += ranges_len[index++];
  }
  std::vector<std::pair<int, int>> first_half, second_half;
  int start_index = rand() % total_size;
  for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
    if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
      first_half.push_back({ranges_pos[i] + start_index,
                            ranges_pos[i] + ranges_len[i] + start_index});
    else if (ranges_pos[i] + start_index >= total_size) {
      second_half.push_back(
          {ranges_pos[i] + start_index - total_size,
           ranges_pos[i] + ranges_len[i] + start_index - total_size});
    } else {
      first_half.push_back({ranges_pos[i] + start_index, total_size});
      second_half.push_back(
          {0, ranges_pos[i] + ranges_len[i] + start_index - total_size});
    }
  }
  for (auto &pair : first_half) second_half.push_back(pair);
  std::vector<uint64_t> res;
  get_nodes_ids_by_ranges(type_id, idx, second_half, res);
  actual_size = res.size() * sizeof(uint64_t);
  buffer.reset(new char[actual_size]);
  char *pointer = buffer.get();
  memcpy(pointer, res.data(), actual_size);
  return 0;
}
int32_t GraphTable::random_sample_neighbors(
    int idx,
    uint64_t *node_ids,
    int sample_size,
    std::vector<std::shared_ptr<char>> &buffers,
    std::vector<int> &actual_sizes,
    bool need_weight) {
  size_t node_num = buffers.size();
  std::function<void(char *)> char_del = [](char *c) { delete[] c; };
  std::vector<std::future<int>> tasks;
  std::vector<std::vector<uint32_t>> seq_id(task_pool_size_);
  std::vector<std::vector<SampleKey>> id_list(task_pool_size_);
  size_t index;
  for (size_t idy = 0; idy < node_num; ++idy) {
    index = get_thread_pool_index(node_ids[idy]);
    seq_id[index].emplace_back(idy);
    id_list[index].emplace_back(idx, node_ids[idy], sample_size, need_weight);
  }
  for (int i = 0; i < (int)seq_id.size(); i++) {
    if (seq_id[i].size() == 0) continue;
    tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
      uint64_t node_id;
      std::vector<std::pair<SampleKey, SampleResult>> r;
      LRUResponse response = LRUResponse::blocked;
      if (use_cache) {
        response =
            scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
      }
      int index = 0;
      std::vector<SampleResult> sample_res;
      std::vector<SampleKey> sample_keys;
      auto &rng = _shards_task_rng_pool[i];
      for (size_t k = 0; k < id_list[i].size(); k++) {
        if (index < (int)r.size() &&
            r[index].first.node_key == id_list[i][k].node_key) {
          int idy = seq_id[i][k];
          actual_sizes[idy] = r[index].second.actual_size;
          buffers[idy] = r[index].second.buffer;
          index++;
        } else {
          node_id = id_list[i][k].node_key;
          Node *node = find_node(0, idx, node_id);
          int idy = seq_id[i][k];
          int &actual_size = actual_sizes[idy];
          if (node == nullptr) {
#ifdef PADDLE_WITH_HETERPS
            if (search_level == 2) {
              VLOG(2) << "enter sample from ssd for node_id " << node_id;
              char *buffer_addr = random_sample_neighbor_from_ssd(
                  idx, node_id, sample_size, rng, actual_size);
              if (actual_size != 0) {
                std::shared_ptr<char> &buffer = buffers[idy];
                buffer.reset(buffer_addr, char_del);
              }
              VLOG(2) << "actual sampled size from ssd = "
                      << actual_sizes[idy];
              continue;
            }
#endif
            actual_size = 0;
            continue;
          }
          std::shared_ptr<char> &buffer = buffers[idy];
          std::vector<int> res = node->sample_k(sample_size, rng);
          actual_size =
              res.size() * (need_weight ? (Node::id_size + Node::weight_size)
                                        : Node::id_size);
          int offset = 0;
          uint64_t id;
          float weight;
          char *buffer_addr = new char[actual_size];
          if (response == LRUResponse::ok) {
            sample_keys.emplace_back(idx, node_id, sample_size, need_weight);
            sample_res.emplace_back(actual_size, buffer_addr);
            buffer = sample_res.back().buffer;
          } else {
            buffer.reset(buffer_addr, char_del);
          }
          for (int &x : res) {
            id = node->get_neighbor_id(x);
            memcpy(buffer_addr + offset, &id, Node::id_size);
            offset += Node::id_size;
            if (need_weight) {
              weight = node->get_neighbor_weight(x);
              memcpy(buffer_addr + offset, &weight, Node::weight_size);
              offset += Node::weight_size;
            }
          }
        }
      }
      if (sample_res.size()) {
        scaled_lru->insert(
            i, sample_keys.data(), sample_res.data(), sample_keys.size());
      }
      return 0;
    }));
  }
  for (auto &t : tasks) {
    t.get();
  }
  return 0;
}
int32_t GraphTable::get_node_feat(
    int idx,
    const std::vector<uint64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  size_t node_num = node_ids.size();
  std::vector<std::future<int>> tasks;
  for (size_t idy = 0; idy < node_num; ++idy) {
    uint64_t node_id = node_ids[idy];
    tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
        [&, idx, idy, node_id]() -> int {
          Node *node = find_node(1, idx, node_id);
          if (node == nullptr) {
            return 0;
          }
          for (int feat_idx = 0; feat_idx < (int)feature_names.size();
               ++feat_idx) {
            const std::string &feature_name = feature_names[feat_idx];
            if (feat_id_map[idx].find(feature_name) !=
                feat_id_map[idx].end()) {
              // res[feat_idx][idx] =
              // node->get_feature(feat_id_map[feature_name]);
              auto feat = node->get_feature(feat_id_map[idx][feature_name]);
              res[feat_idx][idy] = feat;
            }
          }
          return 0;
        }));
  }
  for (size_t idy = 0; idy < node_num; ++idy) {
    tasks[idy].get();
  }
  return 0;
}
int32_t GraphTable::set_node_feat(
    int idx,
    const std::vector<uint64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &res) {
  size_t node_num = node_ids.size();
  std::vector<std::future<int>> tasks;
  for (size_t idy = 0; idy < node_num; ++idy) {
    uint64_t node_id = node_ids[idy];
    tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
        [&, idx, idy, node_id]() -> int {
          size_t index = node_id % this->shard_num - this->shard_start;
          auto node = feature_shards[idx][index]->add_feature_node(node_id);
          node->set_feature_size(this->feat_name[idx].size());
          for (int feat_idx = 0; feat_idx < (int)feature_names.size();
               ++feat_idx) {
            const std::string &feature_name = feature_names[feat_idx];
            if (feat_id_map[idx].find(feature_name) !=
                feat_id_map[idx].end()) {
              node->set_feature(feat_id_map[idx][feature_name],
                                res[feat_idx][idy]);
            }
          }
          return 0;
        }));
  }
  for (size_t idy = 0; idy < node_num; ++idy) {
    tasks[idy].get();
  }
  return 0;
}
void string_vector_2_string(std::vector<std::string>::iterator strs_begin,
                            std::vector<std::string>::iterator strs_end,
                            char delim,
                            std::string *output) {
  size_t i = 0;
  for (std::vector<std::string>::iterator iter = strs_begin; iter != strs_end;
       ++iter) {
    if (i > 0) {
      *output += delim;
    }
    *output += *iter;
    ++i;
  }
}
void string_vector_2_string(
    std::vector<paddle::string::str_ptr>::iterator strs_begin,
    std::vector<paddle::string::str_ptr>::iterator strs_end,
    char delim,
    std::string *output) {
  size_t i = 0;
  for (auto iter = strs_begin; iter != strs_end; ++iter) {
    if (i > 0) {
      output->append(&delim, 1);
    }
    output->append((*iter).ptr, (*iter).len);
    ++i;
  }
}
int GraphTable::parse_feature(int idx,
                              const char *feat_str,
                              size_t len,
                              FeatureNode *node) {
  // Return (feat_id, bytes) if name is in this->feat_name, else return
  // (-1, "")
  thread_local std::vector<paddle::string::str_ptr> fields;
  fields.clear();
  const char c = feature_separator_.at(0);
  paddle::string::split_string_ptr(feat_str, len, c, &fields);
  std::string name = fields[0].to_string();
  auto it = feat_id_map[idx].find(name);
  if (it != feat_id_map[idx].end()) {
    int32_t id = it->second;
    std::string *fea_ptr = node->mutable_feature(id);
    std::string dtype = this->feat_dtype[idx][id];
    if (dtype == "feasign") {
      // string_vector_2_string(fields.begin() + 1, fields.end(), ' ',
      // fea_ptr);
      FeatureNode::parse_value_to_bytes<uint64_t>(
          fields.begin() + 1, fields.end(), fea_ptr);
      return 0;
    } else if (dtype == "string") {
      string_vector_2_string(fields.begin() + 1, fields.end(), ' ', fea_ptr);
      return 0;
    } else if (dtype == "float32") {
      FeatureNode::parse_value_to_bytes<float>(
          fields.begin() + 1, fields.end(), fea_ptr);
      return 0;
    } else if (dtype == "float64") {
      FeatureNode::parse_value_to_bytes<double>(
          fields.begin() + 1, fields.end(), fea_ptr);
      return 0;
    } else if (dtype == "int32") {
      FeatureNode::parse_value_to_bytes<int32_t>(
          fields.begin() + 1, fields.end(), fea_ptr);
      return 0;
    } else if (dtype == "int64") {
      FeatureNode::parse_value_to_bytes<uint64_t>(
          fields.begin() + 1, fields.end(), fea_ptr);
      return 0;
    }
  } else {
    VLOG(2) << "feature_name[" << name << "] is not in feat_id_map, ntype_id["
            << idx << "] feat_id_map_size[" << feat_id_map.size() << "]";
  }
  return -1;
}
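// Example added for illustration (not part of the original commit):
// parse_feature expects one feature per string, laid out as
// "<name><sep><v1><sep><v2>...", where <sep> is feature_separator_
// (a single character, " " by default). Assuming a float32 feature named
// "emb" has been registered in feat_id_map[idx], a hypothetical input is:
//   const char *feat_str = "emb 0.1 0.2 0.3";
//   parse_feature(idx, feat_str, strlen(feat_str), node);  // returns 0
// An unknown feature name only logs at VLOG(2) and the call returns -1.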
// thread safe shard vector merge
class MergeShardVector {
 public:
  MergeShardVector(std::vector<std::vector<uint64_t>> *output, int slice_num) {
    _slice_num = slice_num;
    _shard_keys = output;
    _shard_keys->resize(slice_num);
    _mutexs = new std::mutex[slice_num];
  }
  ~MergeShardVector() {
    if (_mutexs != nullptr) {
      delete[] _mutexs;
      _mutexs = nullptr;
    }
  }
  // merge shard keys
  void merge(const std::vector<std::vector<uint64_t>> &shard_keys) {
    // add to shard
    for (int shard_id = 0; shard_id < _slice_num; ++shard_id) {
      auto &dest = (*_shard_keys)[shard_id];
      auto &src = shard_keys[shard_id];
      _mutexs[shard_id].lock();
      dest.insert(dest.end(), src.begin(), src.end());
      _mutexs[shard_id].unlock();
    }
  }

 private:
  int _slice_num = 0;
  std::mutex *_mutexs = nullptr;
  std::vector<std::vector<uint64_t>> *_shard_keys;
};
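// Usage sketch added for illustration (not part of the original commit):
// MergeShardVector lets many shard tasks append their per-slice keys into one
// shared output without racing, since every slice is guarded by its own mutex.
//   std::vector<std::vector<uint64_t>> output;
//   MergeShardVector merger(&output, /*slice_num=*/8);
//   std::vector<std::vector<uint64_t>> local(8);
//   local[3].push_back(42);   // keys produced by one worker
//   merger.merge(local);      // thread-safe append into output[3]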
int GraphTable::get_all_id(int type_id,
                           int slice_num,
                           std::vector<std::vector<uint64_t>> *output) {
  MergeShardVector shard_merge(output, slice_num);
  auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
  std::vector<std::future<size_t>> tasks;
  for (int idx = 0; idx < search_shards.size(); idx++) {
    for (int j = 0; j < search_shards[idx].size(); j++) {
      tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
          [&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
            std::vector<std::vector<uint64_t>> shard_keys;
            size_t num =
                search_shards[idx][j]->get_all_id(&shard_keys, slice_num);
            // add to shard
            shard_merge.merge(shard_keys);
            return num;
          }));
    }
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  return 0;
}
int GraphTable::get_all_neighbor_id(
    int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
  MergeShardVector shard_merge(output, slice_num);
  auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
  std::vector<std::future<size_t>> tasks;
  for (int idx = 0; idx < search_shards.size(); idx++) {
    for (int j = 0; j < search_shards[idx].size(); j++) {
      tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
          [&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
            std::vector<std::vector<uint64_t>> shard_keys;
            size_t num = search_shards[idx][j]->get_all_neighbor_id(
                &shard_keys, slice_num);
            // add to shard
            shard_merge.merge(shard_keys);
            return num;
          }));
    }
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  return 0;
}
int GraphTable::get_all_id(int type_id,
                           int idx,
                           int slice_num,
                           std::vector<std::vector<uint64_t>> *output) {
  MergeShardVector shard_merge(output, slice_num);
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<size_t>> tasks;
  VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
  for (size_t i = 0; i < search_shards.size(); i++) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, i, slice_num, &shard_merge]() -> size_t {
          std::vector<std::vector<uint64_t>> shard_keys;
          size_t num = search_shards[i]->get_all_id(&shard_keys, slice_num);
          // add to shard
          shard_merge.merge(shard_keys);
          return num;
        }));
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  VLOG(3) << "end task, task_pool_size_[" << task_pool_size_ << "]";
  return 0;
}
int GraphTable::get_all_neighbor_id(
    int type_id,
    int idx,
    int slice_num,
    std::vector<std::vector<uint64_t>> *output) {
  MergeShardVector shard_merge(output, slice_num);
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<size_t>> tasks;
  VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
  for (int i = 0; i < search_shards.size(); i++) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, i, slice_num, &shard_merge]() -> size_t {
          std::vector<std::vector<uint64_t>> shard_keys;
          size_t num =
              search_shards[i]->get_all_neighbor_id(&shard_keys, slice_num);
          // add to shard
          shard_merge.merge(shard_keys);
          return num;
        }));
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  VLOG(3) << "end task, task_pool_size_[" << task_pool_size_ << "]";
  return 0;
}
int GraphTable::get_all_feature_ids(
    int type_id,
    int idx,
    int slice_num,
    std::vector<std::vector<uint64_t>> *output) {
  MergeShardVector shard_merge(output, slice_num);
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<size_t>> tasks;
  for (int i = 0; i < search_shards.size(); i++) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, i, slice_num, &shard_merge]() -> size_t {
          std::vector<std::vector<uint64_t>> shard_keys;
          size_t num =
              search_shards[i]->get_all_feature_ids(&shard_keys, slice_num);
          // add to shard
          shard_merge.merge(shard_keys);
          return num;
        }));
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  return 0;
}
int32_t GraphTable::pull_graph_list(int type_id,
                                    int idx,
                                    int start,
                                    int total_size,
                                    std::unique_ptr<char[]> &buffer,
                                    int &actual_size,
                                    bool need_feature,
                                    int step) {
  if (start < 0) start = 0;
  int size = 0, cur_size;
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<std::vector<Node *>>> tasks;
  for (size_t i = 0; i < search_shards.size() && total_size > 0; i++) {
    cur_size = search_shards[i]->get_size();
    if (size + cur_size <= start) {
      size += cur_size;
      continue;
    }
    int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
    int end = start + (count - 1) * step + 1;
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, this, i, start, end, step, size]()
            -> std::vector<Node *> {
          return search_shards[i]->get_batch(start - size, end - size, step);
        }));
    start += count * step;
    total_size -= count;
    size += cur_size;
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  size = 0;
  std::vector<std::vector<Node *>> res;
  for (size_t i = 0; i < tasks.size(); i++) {
    res.push_back(tasks[i].get());
    for (size_t j = 0; j < res.back().size(); j++) {
      size += res.back()[j]->get_size(need_feature);
    }
  }
  char *buffer_addr = new char[size];
  buffer.reset(buffer_addr);
  int index = 0;
  for (size_t i = 0; i < res.size(); i++) {
    for (size_t j = 0; j < res[i].size(); j++) {
      res[i][j]->to_buffer(buffer_addr + index, need_feature);
      index += res[i][j]->get_size(need_feature);
    }
  }
  actual_size = size;
  return 0;
}
void GraphTable::set_feature_separator(const std::string &ch) {
  feature_separator_ = ch;
}
int32_t GraphTable::get_server_index_by_id(uint64_t id) {
  return id % shard_num / shard_num_per_server;
}
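// Worked example added for illustration (not part of the original commit):
// with hypothetical values shard_num = 1000 and shard_num_per_server = 125,
// id 123456789 hashes to shard 789, and 789 / 125 = 6, so the key is owned by
// server index 6. This division pairs with shard_start / shard_end, which
// bound the contiguous range of shards held by the local server.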
int32_t GraphTable::Initialize(const TableParameter &config,
                               const FsClientParameter &fs_config) {
  LOG(INFO) << "in graphTable initialize";
  _config = config;
  if (InitializeAccessor() != 0) {
    LOG(WARNING) << "Table accessor initialize failed";
    return -1;
  }
  if (_afs_client.initialize(fs_config) != 0) {
    LOG(WARNING) << "Table fs_client initialize failed";
    // return -1;
  }
  auto graph = config.graph_parameter();
  shard_num = _config.shard_num();
  LOG(INFO) << "in graphTable initialize over";
  return Initialize(graph);
}
void GraphTable::load_node_weight(int type_id, int idx, std::string path) {
  auto paths = paddle::string::split_string<std::string>(path, ";");
  int64_t count = 0;
  auto &weight_map = node_weight[type_id][idx];
  for (auto path : paths) {
    std::ifstream file(path);
    std::string line;
    while (std::getline(file, line)) {
      auto values = paddle::string::split_string<std::string>(line, "\t");
      count++;
      if (values.size() < 2) continue;
      auto src_id = std::stoull(values[0]);
      double weight = std::stod(values[1]);
      weight_map[src_id] = weight;
    }
  }
}
int32_t GraphTable::Initialize(const GraphParameter &graph) {
  task_pool_size_ = graph.task_pool_size();
  build_sampler_on_cpu = graph.build_sampler_on_cpu();
#ifdef PADDLE_WITH_HETERPS
  _db = NULL;
  search_level = graph.search_level();
  if (search_level >= 2) {
    _db = paddle::distributed::RocksDBHandler::GetInstance();
    _db->initialize("./temp_gpups_db", task_pool_size_);
  }
  // gpups_mode = true;
  // auto *sampler =
  //     CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
  // auto slices =
  //     string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
  // std::cout << "slices" << std::endl;
  // for (auto x : slices) std::cout << x << std::endl;
  // sampler->init(graph.gpu_num(), this, slices);
  // graph_sampler.reset(sampler);
#endif
  if (shard_num == 0) {
    server_num = 1;
    _shard_idx = 0;
    shard_num = graph.shard_num();
  }
  use_cache = graph.use_cache();
  if (use_cache) {
    cache_size_limit = graph.cache_size_limit();
    cache_ttl = graph.cache_ttl();
    make_neighbor_sample_cache((size_t)cache_size_limit, (size_t)cache_ttl);
  }
  _shards_task_pool.resize(task_pool_size_);
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
    _shards_task_pool[i].reset(new ::ThreadPool(1));
    _shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
  }
  load_node_edge_task_pool.reset(new ::ThreadPool(load_thread_num));
  auto graph_feature = graph.graph_feature();
  auto node_types = graph.node_types();
  auto edge_types = graph.edge_types();
  VLOG(0) << "got " << edge_types.size() << " edge types in total";
  feat_id_map.resize(node_types.size());
  for (int k = 0; k < edge_types.size(); k++) {
    VLOG(0) << "in initialize: get an edge_type " << edge_types[k];
    edge_to_id[edge_types[k]] = k;
    id_to_edge.push_back(edge_types[k]);
  }
  feat_name.resize(node_types.size());
  feat_shape.resize(node_types.size());
  feat_dtype.resize(node_types.size());
  VLOG(0) << "got " << node_types.size() << " node types in total";
  for (int k = 0; k < node_types.size(); k++) {
    feature_to_id[node_types[k]] = k;
    auto node_type = node_types[k];
    auto feature = graph_feature[k];
    id_to_feature.push_back(node_type);
    int feat_conf_size = static_cast<int>(feature.name().size());
    for (int i = 0; i < feat_conf_size; i++) {
      // auto &f_name = common.attributes()[i];
      // auto &f_shape = common.dims()[i];
      // auto &f_dtype = common.params()[i];
      auto &f_name = feature.name()[i];
      auto &f_shape = feature.shape()[i];
      auto &f_dtype = feature.dtype()[i];
      feat_name[k].push_back(f_name);
      feat_shape[k].push_back(f_shape);
      feat_dtype[k].push_back(f_dtype);
      feat_id_map[k][f_name] = i;
      VLOG(0) << "init graph table feat conf name:" << f_name
              << " shape:" << f_shape << " dtype:" << f_dtype;
    }
  }
  // this->table_name = common.table_name();
  // this->table_type = common.name();
  this->table_name = graph.table_name();
  this->table_type = graph.table_type();
  VLOG(0) << " init graph table type " << this->table_type << " table name "
          << this->table_name;
  // int feat_conf_size = static_cast<int>(common.attributes().size());
  // int feat_conf_size = static_cast<int>(graph_feature.name().size());
  VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx "
          << _shard_idx;
  shard_num_per_server = sparse_local_shard_num(shard_num, server_num);
  shard_start = _shard_idx * shard_num_per_server;
  shard_end = shard_start + shard_num_per_server;
  VLOG(0) << "in init graph table shard idx = " << _shard_idx
          << " shard_start " << shard_start << " shard_end " << shard_end;
  edge_shards.resize(id_to_edge.size());
  node_weight.resize(2);
  node_weight[0].resize(id_to_edge.size());
#ifdef PADDLE_WITH_HETERPS
  partitions.resize(id_to_edge.size());
#endif
  for (int k = 0; k < (int)edge_shards.size(); k++) {
    for (size_t i = 0; i < shard_num_per_server; i++) {
      edge_shards[k].push_back(new GraphShard());
    }
  }
  node_weight[1].resize(id_to_feature.size());
  feature_shards.resize(id_to_feature.size());
  for (int k = 0; k < (int)feature_shards.size(); k++) {
    for (size_t i = 0; i < shard_num_per_server; i++) {
      feature_shards[k].push_back(new GraphShard());
    }
  }
  return 0;
}
}  // namespace distributed
};  // namespace paddle
paddle/fluid/distributed/ps/table/common_graph_table.h
0 → 100644
View file @
f0ef3442
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
namespace distributed {
class GraphShard {
 public:
  size_t get_size();
  GraphShard() {}
  ~GraphShard();
  std::vector<Node *> &get_bucket() { return bucket; }
  std::vector<Node *> get_batch(int start, int end, int step);
  void get_ids_by_range(int start, int end, std::vector<uint64_t> *res) {
    res->reserve(res->size() + end - start);
    for (int i = start; i < end && i < (int)bucket.size(); i++) {
      res->emplace_back(bucket[i]->get_id());
    }
  }
  size_t get_all_id(std::vector<std::vector<uint64_t>> *shard_keys,
                    int slice_num) {
    int bucket_num = bucket.size();
    shard_keys->resize(slice_num);
    for (int i = 0; i < slice_num; ++i) {
      (*shard_keys)[i].reserve(bucket_num / slice_num);
    }
    for (int i = 0; i < bucket_num; i++) {
      uint64_t k = bucket[i]->get_id();
      (*shard_keys)[k % slice_num].emplace_back(k);
    }
    return bucket_num;
  }
  size_t get_all_neighbor_id(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
    for (size_t i = 0; i < bucket.size(); i++) {
      size_t neighbor_size = bucket[i]->get_neighbor_size();
      size_t n = keys.size();
      keys.resize(n + neighbor_size);
      for (size_t j = 0; j < neighbor_size; j++) {
        keys[n + j] = bucket[i]->get_neighbor_id(j);
      }
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t get_all_feature_ids(std::vector<std::vector<uint64_t>> *total_res,
                             int slice_num) {
    std::vector<uint64_t> keys;
    for (int i = 0; i < (int)bucket.size(); i++) {
      bucket[i]->get_feature_ids(&keys);
    }
    return dedup2shard_keys(&keys, total_res, slice_num);
  }
  size_t dedup2shard_keys(std::vector<uint64_t> *keys,
                          std::vector<std::vector<uint64_t>> *total_res,
                          int slice_num) {
    size_t num = keys->size();
    uint64_t last_key = 0;
    // sort key insert to vector
    std::sort(keys->begin(), keys->end());
    total_res->resize(slice_num);
    for (int shard_id = 0; shard_id < slice_num; ++shard_id) {
      (*total_res)[shard_id].reserve(num / slice_num);
    }
    for (size_t i = 0; i < num; ++i) {
      const uint64_t &k = (*keys)[i];
      if (i > 0 && last_key == k) {
        continue;
      }
      last_key = k;
      (*total_res)[k % slice_num].push_back(k);
    }
    return num;
  }
  GraphNode *add_graph_node(uint64_t id);
  GraphNode *add_graph_node(Node *node);
  FeatureNode *add_feature_node(uint64_t id, bool is_overlap = true);
  Node *find_node(uint64_t id);
  void delete_node(uint64_t id);
  void clear();
  void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
  std::unordered_map<uint64_t, int> &get_node_location() {
    return node_location;
  }

 private:
  std::unordered_map<uint64_t, int> node_location;
  std::vector<Node *> bucket;
};
enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey {
  int idx;
  uint64_t node_key;
  size_t sample_size;
  bool is_weighted;
  SampleKey(int _idx,
            uint64_t _node_key,
            size_t _sample_size,
            bool _is_weighted) {
    idx = _idx;
    node_key = _node_key;
    sample_size = _sample_size;
    is_weighted = _is_weighted;
  }
  bool operator==(const SampleKey &s) const {
    return idx == s.idx && node_key == s.node_key &&
           sample_size == s.sample_size && is_weighted == s.is_weighted;
  }
};
class SampleResult {
 public:
  size_t actual_size;
  std::shared_ptr<char> buffer;
  SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
      : actual_size(_actual_size), buffer(_buffer) {}
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, [](char *p) { delete[] p; }) {}
  ~SampleResult() {}
};
template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
    next = pre = NULL;
  }
  K key;
  V data;
  size_t ttl;
  // time to live
  LRUNode<K, V> *pre, *next;
};
template <typename K, typename V>
class ScaledLRU;
template <typename K, typename V>
class RandomSampleLRU {
 public:
  RandomSampleLRU(ScaledLRU<K, V> *_father) {
    father = _father;
    remove_count = 0;
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
    total_diff = 0;
  }
  ~RandomSampleLRU() {
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
  LRUResponse query(K *keys, size_t length,
                    std::vector<std::pair<K, V>> &res) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
        }
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  void remove(LRUNode<K, V> *node) {
    fetch(node);
    node_size--;
    key_map.erase(node->key);
    delete node;
  }
  void process_redundant(int process_size) {
    int length = std::min(remove_count, process_size);
    while (length--) {
      remove(node_head);
      remove_count--;
    }
    // std::cerr<<"after remove_count = "<<remove_count<<std::endl;
  }
  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }
  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  void place_at_tail(LRUNode<K, V> *node) {
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }
  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }

 private:
  std::unordered_map<K, LRUNode<K, V> *> key_map;
  ScaledLRU<K, V> *father;
  size_t global_ttl, size_limit;
  int node_size, total_diff;
  LRUNode<K, V> *node_head, *node_end;
  friend class ScaledLRU<K, V>;
  int remove_count;
};
template <typename K, typename V>
class ScaledLRU {
 public:
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
      : size_limit(size_limit), ttl(_ttl) {
    shard_num = _shard_num;
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
    lru_pool = std::vector<RandomSampleLRU<K, V>>(
        shard_num, RandomSampleLRU<K, V>(this));
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
          if (stop) {
            return;
          }
        }
        auto status =
            thread_pool->enqueue([this]() -> int { return Shrink(); });
        status.wait();
      }
    });
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
  LRUResponse query(size_t index, K *keys, size_t length,
                    std::vector<std::pair<K, V>> &res) {
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
  int Shrink() {
    int node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
    }
    if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
      if ((size_t)global_count > size_limit) {
        size_t remove = global_count - size_limit;
        for (size_t i = 0; i < lru_pool.size(); i++) {
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > int(1.25 * size_limit)) {
        thread_pool->enqueue([this]() -> int { return Shrink(); });
      }
    }
  }
  size_t get_ttl() { return ttl; }

 private:
  pthread_rwlock_t rwlock;
  size_t shard_num;
  int global_count;
  size_t size_limit, total, hit;
  size_t ttl;
  bool stop;
  std::thread shrink_job;
  std::vector<RandomSampleLRU<K, V>> lru_pool;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;
  friend class RandomSampleLRU<K, V>;
};
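// Usage sketch added for illustration (not part of the original commit):
// ScaledLRU shards the neighbor-sample cache into one RandomSampleLRU per
// worker thread; callers pass their worker index so lookups stay lock-light
// (tryrdlock may report blocked), while a background job shrinks the pool
// back toward size_limit.
//   ScaledLRU<SampleKey, SampleResult> cache(/*_shard_num=*/24,
//                                            /*size_limit=*/100000,
//                                            /*_ttl=*/5);
//   std::vector<std::pair<SampleKey, SampleResult>> hits;
//   // keys/results prepared by worker i:
//   //   cache.insert(i, keys.data(), results.data(), keys.size());
//   //   cache.query(i, keys.data(), keys.size(), hits);  // ok or blocked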
/*
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int loadData(const std::string &path){
return 0;
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
*/
class GraphTable : public Table {
 public:
  GraphTable() {
    use_cache = false;
    shard_num = 0;
    rw_lock.reset(new pthread_rwlock_t());
#ifdef PADDLE_WITH_HETERPS
    next_partition = 0;
    total_memory_cost = 0;
#endif
  }
  virtual ~GraphTable();
  virtual void *GetShard(size_t shard_idx) { return 0; }
  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }
  static size_t get_sparse_shard(uint32_t shard_num,
                                 uint32_t server_num,
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }
  virtual int32_t pull_graph_list(int type_id,
                                  int idx,
                                  int start,
                                  int size,
                                  std::unique_ptr<char[]> &buffer,
                                  int &actual_size,
                                  bool need_feature,
                                  int step);
  virtual int32_t random_sample_neighbors(
      int idx,
      uint64_t *node_ids,
      int sample_size,
      std::vector<std::shared_ptr<char>> &buffers,
      std::vector<int> &actual_sizes,
      bool need_weight);
  int32_t random_sample_nodes(int type_id,
                              int idx,
                              int sample_size,
                              std::unique_ptr<char[]> &buffers,
                              int &actual_sizes);
  virtual int32_t get_nodes_ids_by_ranges(
      int type_id,
      int idx,
      std::vector<std::pair<int, int>> ranges,
      std::vector<uint64_t> &res);
  virtual int32_t Initialize() { return 0; }
  virtual int32_t Initialize(const TableParameter &config,
                             const FsClientParameter &fs_config);
  virtual int32_t Initialize(const GraphParameter &config);
  int32_t Load(const std::string &path, const std::string &param);
  int32_t load_node_and_edge_file(std::string etype,
                                  std::string ntype,
                                  std::string epath,
                                  std::string npath,
                                  int part_num,
                                  bool reverse);
  std::string get_inverse_etype(std::string &etype);
  int32_t load_edges(const std::string &path,
                     bool reverse,
                     const std::string &edge_type);
  int get_all_id(int type,
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
  int get_all_neighbor_id(int type,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int get_all_id(int type,
                 int idx,
                 int slice_num,
                 std::vector<std::vector<uint64_t>> *output);
  int get_all_neighbor_id(int type_id,
                          int id,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int get_all_feature_ids(int type,
                          int idx,
                          int slice_num,
                          std::vector<std::vector<uint64_t>> *output);
  int32_t load_nodes(const std::string &path,
                     std::string node_type = std::string());
  std::pair<uint64_t, uint64_t> parse_edge_file(const std::string &path,
                                                int idx,
                                                bool reverse);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path,
                                                const std::string &node_type,
                                                int idx);
  std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path);
  int32_t add_graph_node(int idx,
                         std::vector<uint64_t> &id_list,
                         std::vector<bool> &is_weight_list);
  int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list);
  int32_t get_server_index_by_id(uint64_t id);
  Node *find_node(int type_id, int idx, uint64_t id);
  Node *find_node(int type_id, uint64_t id);
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }
  virtual int32_t clear_nodes(int type, int idx);
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Shrink(const std::string &param) { return 0; }
  // specify the save path
  virtual int32_t Save(const std::string &path, const std::string &converter) {
    return 0;
  }
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
    _shard_idx = shard_idx;
    /*
    _shard_num is not used in graph_table, this following operation is for the
    purpose of
    being compatible with base class table.
    */
    _shard_num = server_num;
    this->server_num = server_num;
    return 0;
  }
  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
  virtual uint32_t get_thread_pool_index(uint64_t node_id);
  virtual int parse_feature(int idx,
                            const char *feat_str,
                            size_t len,
                            FeatureNode *node);
  virtual int32_t get_node_feat(int idx,
                                const std::vector<uint64_t> &node_ids,
                                const std::vector<std::string> &feature_names,
                                std::vector<std::vector<std::string>> &res);
  virtual int32_t set_node_feat(
      int idx,
      const std::vector<uint64_t> &node_ids,
      const std::vector<std::string> &feature_names,
      const std::vector<std::vector<std::string>> &res);
  size_t get_server_num() { return server_num; }
  void clear_graph(int idx);
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
            task_pool_size_, size_limit, ttl));
        use_cache = true;
      }
    }
    return 0;
  }
  virtual void load_node_weight(int type_id, int idx, std::string path);
#ifdef PADDLE_WITH_HETERPS
  // virtual int32_t start_graph_sampling() {
  //   return this->graph_sampler->start_graph_sampling();
  // }
  // virtual int32_t end_graph_sampling() {
  //   return this->graph_sampler->end_graph_sampling();
  // }
  // virtual int32_t set_graph_sample_callback(
  //     std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
  //         callback) {
  //   graph_sampler->set_graph_sample_callback(callback);
  //   return 0;
  // }
  virtual void make_partitions(int idx, int64_t gb_size, int device_len);
  virtual void export_partition_files(int idx, std::string file_path);
  virtual char *random_sample_neighbor_from_ssd(
      int idx,
      uint64_t id,
      int sample_size,
      const std::shared_ptr<std::mt19937_64> rng,
      int &actual_size);
  virtual int32_t add_node_to_ssd(
      int type_id, int idx, uint64_t src_id, char *data, int len);
  virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
      int idx, std::vector<uint64_t> ids);
  virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
      std::vector<uint64_t> &node_ids, int slot_num);
  int32_t Load_to_ssd(const std::string &path, const std::string &param);
  int64_t load_graph_to_memory_from_ssd(int idx, std::vector<uint64_t> &ids);
  int32_t make_complementary_graph(int idx, int64_t byte_size);
  int32_t dump_edges_to_ssd(int idx);
  int32_t get_partition_num(int idx) { return partitions[idx].size(); }
  std::vector<uint64_t> get_partition(int idx, int index) {
    if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
      return std::vector<uint64_t>();
    return partitions[idx][index];
  }
  int32_t load_edges_to_ssd(const std::string &path,
                            bool reverse_edge,
                            const std::string &edge_type);
  int32_t load_next_partition(int idx);
  void set_search_level(int search_level) {
    this->search_level = search_level;
  }
  int search_level;
  int64_t total_memory_cost;
  std::vector<std::vector<std::vector<uint64_t>>> partitions;
  int next_partition;
#endif
  virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
  virtual int32_t build_sampler(int idx, std::string sample_type = "random");
  void set_feature_separator(const std::string &ch);
  std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
  int task_pool_size_ = 24;
  int load_thread_num = 160;
  const int random_sample_nodes_ranges = 3;
  std::vector<std::vector<std::unordered_map<uint64_t, double>>> node_weight;
  std::vector<std::vector<std::string>> feat_name;
  std::vector<std::vector<std::string>> feat_dtype;
  std::vector<std::vector<int32_t>> feat_shape;
  std::vector<std::unordered_map<std::string, int32_t>> feat_id_map;
  std::unordered_map<std::string, int> feature_to_id, edge_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  std::string table_name;
  std::string table_type;
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
  std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
  std::unordered_set<uint64_t> extra_nodes;
  std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
  bool use_cache, use_duplicate_nodes;
  int cache_size_limit;
  int cache_ttl;
  mutable std::mutex mutex_;
  bool build_sampler_on_cpu;
  std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
  // paddle::framework::GpuPsGraphTable gpu_graph_table;
  paddle::distributed::RocksDBHandler *_db;
  // std::shared_ptr<::ThreadPool> graph_sample_pool;
  // std::shared_ptr<GraphSampler> graph_sampler;
  // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
  std::string feature_separator_ = std::string(" ");
};
/*
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<uint64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<uint64_t>> sample_neighbors;
size_t gpu_num;
int init_search_size, node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<uint64_t, std::vector<uint64_t>>>
sample_neighbors_map;
};
#endif
*/
}  // namespace distributed
};  // namespace paddle

namespace std {
template <>
struct hash<paddle::distributed::SampleKey> {
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    return s.idx ^ s.node_key ^ s.sample_size;
  }
};
}  // namespace std
paddle/fluid/distributed/ps/table/common_table.h
0 → 100644
View file @
f0ef3442
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <set>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
template <typename T>
struct ReservoirValue {
  std::vector<T> values;
  uint32_t counter;
  uint32_t dim;
  ReservoirValue() {
    dim = 0;
    values.resize(dim);
    counter = 0;
  }
  ReservoirValue(uint32_t dim) {
    this->dim = dim;
    values.resize(dim);
    counter = 0;
  }
  void add(const T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    counter++;
  }
  void add(T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    counter++;
  }
  void avg() {
    if (counter == 0) return;
    auto scale = 1 / static_cast<T>(counter);
    GetBlas<T>().SCAL(values.size(), scale, values.data());
  }
  void reset() {
    std::fill(values.begin(), values.end(), 0);
    counter = 0;
  }
};
class BarrierTable : public Table {
 public:
  BarrierTable() {}
  virtual ~BarrierTable() {}
  virtual void *GetShard(size_t shard_idx) { return 0; }
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }
  int32_t Shrink(const std::string &param) override { return 0; }
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Load(const std::string &path, const std::string &param) {
    return 0;
  }
  virtual int32_t Save(const std::string &path, const std::string &param) {
    return 0;
  }
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t Initialize() override;
  // only for barrier
  // 0: send_barrier 1: recv_barrier 2: complete
  virtual int32_t Barrier(const uint32_t trainer_id,
                          const std::string barrier_type) override;
  virtual int32_t SetTableMap(
      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map)
      override;

 private:
  std::mutex mutex_;
  std::condition_variable trainer_wait_;
  std::set<uint64_t> trainer_ids_;
  std::set<uint64_t> trainer_all_;
  std::atomic<int> trigger_;
  std::atomic<bool> exit_;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/ps/table/ctr_accessor.cc
0 → 100644
View file @
f0ef3442
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
int CtrCommonAccessor::Initialize() {
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
  common_feature_value.embedx_dim = _config.embedx_dim();
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
  _show_click_decay_rate =
      _config.ctr_accessor_param().show_click_decay_rate();
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();
  if (_config.ctr_accessor_param().show_scale()) {
    _show_scale = true;
  }
  InitAccessorInfo();
  return 0;
}
void CtrCommonAccessor::InitAccessorInfo() {
  _accessor_info.dim = common_feature_value.Dim();
  _accessor_info.size = common_feature_value.Size();
  auto embedx_dim = _config.embedx_dim();
  _accessor_info.select_dim = 3 + embedx_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  _accessor_info.update_dim = 4 + embedx_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  _accessor_info.mf_size =
      (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
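// Worked example added for illustration (not part of the original commit):
// with a hypothetical embedx_dim of 8, the pull (select) record holds
// 3 + 8 = 11 floats (show, click, embed_w plus the embedx_w block) and the
// push (update) record holds 4 + 8 = 12 floats (slot, show, click, embed_g
// plus the embedx_g block); mf_size covers embedx_w together with its
// optimizer state, i.e. (8 + embedx_sgd_dim) floats.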
bool CtrCommonAccessor::Shrink(float *value) {
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
  // time_decay first
  common_feature_value.Show(value) *= _show_click_decay_rate;
  common_feature_value.Click(value) *= _show_click_decay_rate;
  // shrink after
  auto score = ShowClickScore(common_feature_value.Show(value),
                              common_feature_value.Click(value));
  auto unseen_days = common_feature_value.UnseenDays(value);
  if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
    return true;
  }
  return false;
}
bool CtrCommonAccessor::SaveCache(float *value,
                                  int param,
                                  double global_cache_threshold) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (ShowClickScore(common_feature_value.Show(value),
                     common_feature_value.Click(value)) >= base_threshold &&
      common_feature_value.UnseenDays(value) <= delta_keep_days) {
    return common_feature_value.Show(value) > global_cache_threshold;
  }
  return false;
}
bool CtrCommonAccessor::SaveSSD(float *value) {
  if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
    return true;
  }
  return false;
}
bool CtrCommonAccessor::Save(float *value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >=
              base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
        // do this after save, because it must not be modified when retry
        if (param == 2) {
          common_feature_value.DeltaScore(value) = 0;
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // do this after save, because it must not be modified when retry
      // common_feature_value.UnseenDays(value)++;
      return true;
    }
    // save revert batch_model
    case 5: {
      return true;
    }
    default:
      return true;
  }
}
void CtrCommonAccessor::UpdateStatAfterSave(float *value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >=
              base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
        common_feature_value.DeltaScore(value) = 0;
      }
    }
      return;
    case 3: {
      common_feature_value.UnseenDays(value)++;
    }
      return;
    default:
      return;
  }
}
int32_t CtrCommonAccessor::Create(float **values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float *value = values[value_item];
    value[common_feature_value.UnseenDaysIndex()] = 0;
    value[common_feature_value.DeltaScoreIndex()] = 0;
    value[common_feature_value.ShowIndex()] = 0;
    value[common_feature_value.ClickIndex()] = 0;
    value[common_feature_value.SlotIndex()] = -1;
    bool zero_init = _config.ctr_accessor_param().zero_init();
    _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
                               value + common_feature_value.EmbedG2SumIndex(),
                               zero_init);
    _embedx_sgd_rule->InitValue(
        value + common_feature_value.EmbedxWIndex(),
        value + common_feature_value.EmbedxG2SumIndex(),
        false);
  }
  return 0;
}
bool CtrCommonAccessor::NeedExtendMF(float *value) {
  float show = value[common_feature_value.ShowIndex()];
  float click = value[common_feature_value.ClickIndex()];
  float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
                click * _config.ctr_accessor_param().click_coeff();
  return score >= _config.embedx_threshold();
}
bool CtrCommonAccessor::HasMF(int size) {
  return size > common_feature_value.EmbedxG2SumIndex();
}
// from CommonFeatureValue to CtrCommonPullValue
int32_t CtrCommonAccessor::Select(float **select_values,
                                  const float **values,
                                  size_t num) {
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float *select_value = select_values[value_item];
    const float *value = values[value_item];
    select_value[CtrCommonPullValue::ShowIndex()] =
        value[common_feature_value.ShowIndex()];
    select_value[CtrCommonPullValue::ClickIndex()] =
        value[common_feature_value.ClickIndex()];
    select_value[CtrCommonPullValue::EmbedWIndex()] =
        value[common_feature_value.EmbedWIndex()];
    memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(),
           value + common_feature_value.EmbedxWIndex(),
           embedx_dim * sizeof(float));
  }
  return 0;
}
// from CtrCommonPushValue to CtrCommonPushValue
// first dim: item
// second dim: field num
int32_t CtrCommonAccessor::Merge(float** update_values,
                                 const float** other_update_values,
                                 size_t num) {
  auto embedx_dim = _config.embedx_dim();
  int total_dim = CtrCommonPushValue::Dim(embedx_dim);
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* other_update_value = other_update_values[value_item];
    for (int i = 0; i < total_dim; ++i) {
      if (i != CtrCommonPushValue::SlotIndex()) {
        update_value[i] += other_update_value[i];
      }
    }
  }
  return 0;
}
// from CtrCommonPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
int32_t CtrCommonAccessor::Update(float** update_values,
                                  const float** push_values,
                                  size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    float push_show = push_value[CtrCommonPushValue::ShowIndex()];
    float push_click = push_value[CtrCommonPushValue::ClickIndex()];
    float slot = push_value[CtrCommonPushValue::SlotIndex()];
    update_value[common_feature_value.ShowIndex()] += push_show;
    update_value[common_feature_value.ClickIndex()] += push_click;
    update_value[common_feature_value.SlotIndex()] = slot;
    update_value[common_feature_value.DeltaScoreIndex()] +=
        (push_show - push_click) *
            _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
    update_value[common_feature_value.UnseenDaysIndex()] = 0;
    // TODO(zhaocaibei123): add configure show_scale
    if (!_show_scale) {
      push_show = 1;
    }
    VLOG(3) << "accessor show scale:" << _show_scale
            << ", push_show:" << push_show;
    _embed_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedWIndex(),
        update_value + common_feature_value.EmbedG2SumIndex(),
        push_value + CtrCommonPushValue::EmbedGIndex(),
        push_show);
    _embedx_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedxWIndex(),
        update_value + common_feature_value.EmbedxG2SumIndex(),
        push_value + CtrCommonPushValue::EmbedxGIndex(),
        push_show);
  }
  return 0;
}
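// CreateValue: admission control for new entries. Pull (stage 0) always
// creates the value; push (stage 1) admits it with probability equal to the
// show/click score, so low-signal features are sampled rather than stored
// unconditionally.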
bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
  // stage == 0, pull
  // stage == 1, push
  if (stage == 0) {
    return true;
  } else if (stage == 1) {
    // operation
    auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
    auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
    auto score = ShowClickScore(show, click);
    if (score <= 0) {
      return false;
    }
    if (score >= 1) {
      return true;
    }
    return local_uniform_real_distribution<float>()(local_random_engine()) <
           score;
  } else {
    return true;
  }
}
float CtrCommonAccessor::ShowClickScore(float show, float click) {
  auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
  auto click_coeff = _config.ctr_accessor_param().click_coeff();
  return (show - click) * nonclk_coeff + click * click_coeff;
}
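// ParseToString / ParseFromString: plain-text (de)serialization of a feature
// value as space-separated floats. The embedx block is written only when the
// show/click score reaches embedx_threshold and the caller requests enough
// columns via param; parsing re-initializes the embedx part before reading
// the floats back in.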
std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
     << v[5];
  for (int i = common_feature_value.EmbedG2SumIndex();
       i < common_feature_value.EmbedxWIndex();
       i++) {
    os << " " << v[i];
  }
  auto show = common_feature_value.Show(const_cast<float*>(v));
  auto click = common_feature_value.Click(const_cast<float*>(v));
  auto score = ShowClickScore(show, click);
  if (score >= _config.embedx_threshold() &&
      param > common_feature_value.EmbedxWIndex()) {
    for (auto i = common_feature_value.EmbedxWIndex();
         i < common_feature_value.Dim();
         ++i) {
      os << " " << v[i];
    }
  }
  return os.str();
}
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
  _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
                              value + common_feature_value.EmbedxG2SumIndex());
  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}
}  // namespace distributed
}  // namespace paddle