Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
de2e6515
Commit
de2e6515
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.4.1-dtk-23.04
parent
ad08b8ce
Pipeline
#228
failed with stages
in 0 seconds
Changes
272
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5330 additions
and
0 deletions
+5330
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+112
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+123
-0
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
.../distributed/fleet_executor/test/sink_interceptor_test.cc
+91
-0
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
...istributed/fleet_executor/test/source_interceptor_test.cc
+85
-0
paddle/fluid/distributed/index_dataset/CMakeLists.txt
paddle/fluid/distributed/index_dataset/CMakeLists.txt
+19
-0
paddle/fluid/distributed/index_dataset/index_dataset.proto
paddle/fluid/distributed/index_dataset/index_dataset.proto
+33
-0
paddle/fluid/distributed/index_dataset/index_sampler.cc
paddle/fluid/distributed/index_dataset/index_sampler.cc
+138
-0
paddle/fluid/distributed/index_dataset/index_sampler.h
paddle/fluid/distributed/index_dataset/index_sampler.h
+139
-0
paddle/fluid/distributed/index_dataset/index_wrapper.cc
paddle/fluid/distributed/index_dataset/index_wrapper.cc
+202
-0
paddle/fluid/distributed/index_dataset/index_wrapper.h
paddle/fluid/distributed/index_dataset/index_wrapper.h
+125
-0
paddle/fluid/distributed/ps/CMakeLists.txt
paddle/fluid/distributed/ps/CMakeLists.txt
+4
-0
paddle/fluid/distributed/ps/README.md
paddle/fluid/distributed/ps/README.md
+39
-0
paddle/fluid/distributed/ps/service/CMakeLists.txt
paddle/fluid/distributed/ps/service/CMakeLists.txt
+122
-0
paddle/fluid/distributed/ps/service/README.md
paddle/fluid/distributed/ps/service/README.md
+8
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
+2067
-0
paddle/fluid/distributed/ps/service/brpc_ps_client.h
paddle/fluid/distributed/ps/service/brpc_ps_client.h
+435
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
+899
-0
paddle/fluid/distributed/ps/service/brpc_ps_server.h
paddle/fluid/distributed/ps/service/brpc_ps_server.h
+240
-0
paddle/fluid/distributed/ps/service/brpc_utils.cc
paddle/fluid/distributed/ps/service/brpc_utils.cc
+351
-0
paddle/fluid/distributed/ps/service/brpc_utils.h
paddle/fluid/distributed/ps/service/brpc_utils.h
+98
-0
No files found.
Too many changes to show.
To preserve performance only
272 of 272+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
0 → 100644
View file @
de2e6515
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
// Wires `nodes` into a linear pipeline: every node becomes the downstream
// of its predecessor and the upstream of its successor. A chain of zero or
// one nodes needs no edges, so it is a no-op.
void LinkNodes(const std::vector<TaskNode*>& nodes) {
  const size_t count = nodes.size();
  if (count <= 1) return;
  // Head of the chain: only a downstream edge.
  nodes[0]->AddDownstreamTask(nodes[1]->task_id());
  // Tail of the chain: only an upstream edge.
  nodes[count - 1]->AddUpstreamTask(nodes[count - 2]->task_id());
  // Interior nodes get both edges.
  for (size_t pos = 1; pos + 1 < count; ++pos) {
    nodes[pos]->AddUpstreamTask(nodes[pos - 1]->task_id());
    nodes[pos]->AddDownstreamTask(nodes[pos + 1]->task_id());
  }
}
// Exercises a long pipeline source->a->b->c->d->e->f->sink on a single
// carrier/rank, with two Amplifier interceptors (node_b, node_e) and four
// Compute interceptors in between. The test drives the pipeline by sending
// a START message to the source and then blocks until the carrier finishes.
TEST(AmplifierInterceptor, Amplifier) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // Register all interceptor ids on rank 0 (single-process test).
  carrier->Init(0,
                {{SOURCE_ID, 0},
                 {0, 0},
                 {1, 0},
                 {2, 0},
                 {3, 0},
                 {4, 0},
                 {5, 0},
                 {SINK_ID, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  // Loopback addresses only; no real network traffic in this test.
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
  int64_t micro_steps = 3;
  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source = new TaskNode(0, SOURCE_ID, micro_steps);
  // rank, task_id, max_run_times
  TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0);
  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
  TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
  TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
  TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
  TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
  // source->a->b->c->d->e->f->sink
  LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});
  // LR->b(1:3)->F->B->e(3:1)->U
  // b amplifies 1:micro_steps on the reply (upstream) direction and
  // e amplifies micro_steps:1 on the send (downstream) direction.
  node_b->SetReplyUpPerSteps(micro_steps);
  node_e->SetSendDownPerSteps(micro_steps);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
  carrier->SetInterceptor(1,
                          InterceptorFactory::Create("Amplifier", 1, node_b));
  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
  carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
  carrier->SetInterceptor(4,
                          InterceptorFactory::Create("Amplifier", 4, node_e));
  carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
  carrier->SetInterceptor(SINK_ID,
                          InterceptorFactory::Create("Sink", SINK_ID, sink));
  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  // Block until the pipeline drains, then release carrier resources.
  carrier->Wait();
  carrier->Release();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
0 → 100644
View file @
de2e6515
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
// Returns the configured buffer size for the edge between `from` and `to`,
// checking both key orientations ({from,to} and {to,from}); edges that are
// not configured fall back to a default buffer size of 2.
//
// FIXED: the map was previously taken by value (a full copy on every call)
// and each hit did two lookups (find + at); it is now a const reference
// with a single iterator-based lookup per orientation.
int64_t GetBuffSize(
    const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>& buffs,
    TaskNode* from,
    TaskNode* to) {
  auto it = buffs.find({from, to});
  if (it != buffs.end()) {
    return it->second;
  }
  it = buffs.find({to, from});
  if (it != buffs.end()) {
    return it->second;
  }
  return 2;  // set default 2
}
// Wires `nodes` into a linear pipeline, like the single-argument LinkNodes,
// but each edge carries a buffer size looked up in `buffs` via GetBuffSize
// (edges absent from `buffs` use its default). No-op for 0 or 1 nodes.
void LinkNodes(const std::vector<TaskNode*>& nodes,
               const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>
                   buffs) {
  const size_t count = nodes.size();
  if (count <= 1) return;
  // Head of the chain: only a downstream edge.
  nodes[0]->AddDownstreamTask(nodes[1]->task_id(),
                              GetBuffSize(buffs, nodes[0], nodes[1]));
  // Tail of the chain: only an upstream edge.
  nodes[count - 1]->AddUpstreamTask(
      nodes[count - 2]->task_id(),
      GetBuffSize(buffs, nodes[count - 2], nodes[count - 1]));
  // Interior nodes get both edges.
  for (size_t pos = 1; pos + 1 < count; ++pos) {
    nodes[pos]->AddUpstreamTask(
        nodes[pos - 1]->task_id(),
        GetBuffSize(buffs, nodes[pos - 1], nodes[pos]));
    nodes[pos]->AddDownstreamTask(
        nodes[pos + 1]->task_id(),
        GetBuffSize(buffs, nodes[pos], nodes[pos + 1]));
  }
}
// Exercises a short pipeline source->a->b->c->d->sink with Amplifier
// interceptors at both ends of the compute section (node_a, node_d) and a
// reduced buffer size (1) on the b->c edge. Drives the pipeline with a
// START message and waits for completion.
TEST(AmplifierInterceptor, Amplifier) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // All interceptor ids live on rank 0 (single-process test).
  carrier->Init(
      0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  // Empty addresses: intra-process messaging only.
  msg_bus->Init(0, {{0, ""}}, "");
  int64_t micro_steps = 6;
  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source = new TaskNode(0, SOURCE_ID, micro_steps);
  // rank, task_id, max_run_times
  TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps, 0);
  // role, rank, task_id
  // b and c run 3 micro-steps while a and d run all 6.
  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
  TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
  // source->a->b->c->d->sink
  // LR->F->B->U
  // The b->c edge gets buffer size 1; all other edges use the default.
  LinkNodes({source, node_a, node_b, node_c, node_d, sink},
            {{{node_b, node_c}, 1}});
  node_a->SetRunPerSteps(micro_steps);
  node_d->SetRunPerSteps(micro_steps);
  node_d->SetRunAtOffset(micro_steps - 1);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0,
                          InterceptorFactory::Create("Amplifier", 0, node_a));
  carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
  carrier->SetInterceptor(3,
                          InterceptorFactory::Create("Amplifier", 3, node_d));
  carrier->SetInterceptor(SINK_ID,
                          InterceptorFactory::Create("Sink", SINK_ID, sink));
  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  // Block until the pipeline drains, then release carrier resources.
  carrier->Wait();
  carrier->Release();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
// Test double placed between the Source and Sink interceptors. On
// DATA_IS_READY it acknowledges upstream (DATA_IS_USELESS to SOURCE_ID) and
// forwards readiness downstream (DATA_IS_READY to SINK_ID); on
// DATA_IS_USELESS it only logs. All message handling goes through NOP via
// the handler registered in the constructor.
class FakeInterceptor : public Interceptor {
 public:
  FakeInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }

  // Message handler: relay readiness from source to sink and acknowledge
  // consumed data back to the source.
  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() == DATA_IS_READY) {
      std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
                << std::endl;
      InterceptorMessage reply;
      reply.set_message_type(DATA_IS_USELESS);
      Send(SOURCE_ID, reply);
      InterceptorMessage ready;
      ready.set_message_type(DATA_IS_READY);
      Send(SINK_ID, ready);
    } else if (msg.message_type() == DATA_IS_USELESS) {
      std::cout << "FakeInterceptor remove result in scope "
                << msg.scope_idx() << std::endl;
    }
  }

 private:
  // FIXED: was declared without an initializer (indeterminate value).
  // Unused by this test double, but kept for parity with the sibling
  // FakeInterceptor in source_interceptor_test.cc.
  int64_t step_{0};
};
// Wires a three-stage graph source->node_a->sink where node_a is the
// FakeInterceptor above. Verifies (by running to completion) that the
// Source interceptor feeds 3 micro-steps through the fake middle stage
// into the Sink interceptor.
TEST(SourceInterceptor, Source) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // All three interceptor ids on rank 0.
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3, 0);
  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
  // role, rank, task_id
  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0);
  // role, rank, task_id
  // Edges with buffer size 1: source -> node_a -> sink.
  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(SOURCE_ID, 1);
  node_a->AddDownstreamTask(SINK_ID, 1);
  sink->AddUpstreamTask(0, 1);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  // The middle stage is the hand-written FakeInterceptor, not a factory type.
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
  carrier->SetInterceptor(SINK_ID,
                          InterceptorFactory::Create("Sink", SINK_ID, sink));
  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  // Block until the graph drains, then release carrier resources.
  carrier->Wait();
  carrier->Release();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
class
FakeInterceptor
:
public
Interceptor
{
public:
FakeInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
step_
=
0
;
RegisterMsgHandle
([
this
](
const
InterceptorMessage
&
msg
)
{
NOP
(
msg
);
});
}
void
NOP
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
DATA_IS_READY
)
{
std
::
cout
<<
"FakeInterceptor run in scope "
<<
msg
.
scope_idx
()
<<
std
::
endl
;
InterceptorMessage
reply
;
reply
.
set_message_type
(
DATA_IS_USELESS
);
Send
(
SOURCE_ID
,
reply
);
step_
++
;
if
(
step_
==
node_
->
max_run_times
())
{
carrier_
->
WakeUp
();
}
}
}
private:
int64_t
step_
;
};
// Wires source->node_a where node_a is the FakeInterceptor above. The
// Source interceptor emits 3 micro-steps; the fake consumer acknowledges
// each one and wakes the carrier after the last, letting Wait() return.
TEST(SourceInterceptor, Source) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // Both interceptor ids on rank 0.
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3, 0);
  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
  // role, rank, task_id
  // Single edge source -> node_a with buffer size 1.
  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(SOURCE_ID, 1);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  // The consumer is the hand-written FakeInterceptor, not a factory type.
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  // Block until FakeInterceptor calls WakeUp, then release resources.
  carrier->Wait();
  carrier->Release();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/index_dataset/CMakeLists.txt
0 → 100644
View file @
de2e6515
# Build rules for the index_dataset component (tree-index data structures).

# Generated C++ protobuf library for index_dataset.proto.
proto_library(index_dataset_proto SRCS index_dataset.proto)

# Tree-index loader/wrapper; needs the proto messages and the fs helpers.
cc_library(
  index_wrapper
  SRCS index_wrapper.cc
  DEPS index_dataset_proto fs)

# Layer-wise sampler; link mkldnn only when the build enables it.
if(WITH_MKLDNN)
  cc_library(
    index_sampler
    SRCS index_sampler.cc
    DEPS xxhash index_wrapper eigen3 mkldnn)
else()
  cc_library(
    index_sampler
    SRCS index_sampler.cc
    DEPS xxhash index_wrapper eigen3)
endif()

# Python bindings for the proto, only for Python-enabled builds.
if(WITH_PYTHON)
  py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
endif()
paddle/fluid/distributed/index_dataset/index_dataset.proto
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;

// One node of the tree index.
message IndexNode {
  required uint64 id = 1;          // node id; leaves map to items
  required bool is_leaf = 2;       // true for leaf (item) nodes
  required float probability = 3;  // sampling probability of this node
  optional string item_name = 4;   // human-readable item name, if any
}

// Global shape of the tree.
message TreeMeta {
  required int32 height = 1;  // number of levels in the tree
  required int32 branch = 2;  // branching factor (children per node)
}

// Length-prefixed key/value record as serialized in the index file
// (".tree_meta" key carries a TreeMeta; other keys are node codes
// carrying IndexNode payloads — see TreeIndex::Load).
message KVItem {
  required bytes key = 1;
  required bytes value = 2;
}
paddle/fluid/distributed/index_dataset/index_sampler.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
distributed
{
// Layer-wise negative sampling. For each target id, walks its path from
// start_sample_layer_ to the root ("travel path") and, per layer j, emits
// one positive row (the path node, label 1) followed by layer_counts_[j]
// negative rows (uniformly sampled same-layer nodes != the path node,
// label 0). Each output row is [user features..., item id, label].
// Rows per input sum to layer_counts_sum_, so `outputs` is sized exactly.
// NOTE(review): assumes all rows of user_inputs have the same length as
// user_inputs[0], and that user_inputs is non-empty — confirm at callers.
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
    const std::vector<std::vector<uint64_t>>& user_inputs,
    const std::vector<uint64_t>& target_ids,
    bool with_hierarchy) {
  auto input_num = target_ids.size();
  auto user_feature_num = user_inputs[0].size();
  // One row per (input, layer, positive-or-negative); +2 columns for the
  // sampled item id and its label.
  std::vector<std::vector<uint64_t>> outputs(
      input_num * layer_counts_sum_,
      std::vector<uint64_t>(user_feature_num + 2));
  auto max_layer = tree_->Height();
  size_t idx = 0;  // next free row in outputs
  for (size_t i = 0; i < input_num; i++) {
    auto travel_codes =
        tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
    auto travel_path = tree_->GetNodes(travel_codes);
    for (size_t j = 0; j < travel_path.size(); j++) {
      // user
      // Fill the user-feature columns for this layer's 1 + layer_counts_[j]
      // rows: either the ancestors of the user features at this layer
      // (hierarchical mode, skipped for the deepest layer j == 0) or the
      // raw user features.
      if (j > 0 && with_hierarchy) {
        auto ancestor_codes =
            tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
        auto hierarchical_user = tree_->GetNodes(ancestor_codes);
        for (int idx_offset = 0; idx_offset <= layer_counts_[j];
             idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = hierarchical_user[k].id();
          }
        }
      } else {
        for (int idx_offset = 0; idx_offset <= layer_counts_[j];
             idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = user_inputs[i][k];
          }
        }
      }
      // sampler ++
      // Positive row: the path node itself with label 1 (the 1.0 literal
      // truncates to uint64_t 1 on assignment).
      outputs[idx][user_feature_num] = travel_path[j].id();
      outputs[idx][user_feature_num + 1] = 1.0;
      idx += 1;
      // Negative rows: resample until the drawn node differs from the
      // positive path node at this layer.
      for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
        int sample_res = 0;
        do {
          sample_res = sampler_vec_[j]->Sample();
        } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
        outputs[idx + idx_offset][user_feature_num] =
            layer_ids_[j][sample_res].id();
        outputs[idx + idx_offset][user_feature_num + 1] = 0;
      }
      idx += layer_counts_[j];
    }
  }
  return outputs;
}
// Expands every source Record into layer-wise positive/negative Records.
// For each record, finds the first feasign whose slot equals sample_slot,
// treats its value as the target item id, walks the target's travel path,
// and appends one positive Record per path node plus layer_counts_[j]
// negative Records with the item replaced by a sampled same-layer node and
// the label feasign (at sample_feasign_idx + 1) zeroed. Records without
// the sample slot are dropped. Results replace *sample_results.
void LayerWiseSampler::sample_from_dataset(
    const uint16_t sample_slot,
    std::vector<paddle::framework::Record>* src_datas,
    std::vector<paddle::framework::Record>* sample_results) {
  sample_results->clear();
  for (auto& data : *src_datas) {
    VLOG(1) << "src data size = " << src_datas->size();
    VLOG(1) << "float data size = " << data.float_feasigns_.size();
    // data.Print();
    uint64_t start_idx = sample_results->size();
    VLOG(1) << "before sample, sample_results.size = " << start_idx;
    // Sentinel: stays as (uint64_t)-1 when the slot is not found; only
    // read when sample_sign is true.
    uint64_t sample_feasign_idx = -1;
    bool sample_sign = false;
    // Locate the first uint64 feasign carrying the sample slot.
    for (unsigned int i = 0; i < data.uint64_feasigns_.size(); i++) {
      VLOG(1) << "slot" << i << " = " << data.uint64_feasigns_[i].slot();
      if (data.uint64_feasigns_[i].slot() == sample_slot) {
        sample_sign = true;
        sample_feasign_idx = i;
      }
      if (sample_sign) break;
    }
    VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx;
    if (sample_sign) {
      auto target_id =
          data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_;
      auto travel_codes =
          tree_->GetTravelCodes(target_id, start_sample_layer_);
      auto travel_path = tree_->GetNodes(travel_codes);
      for (unsigned int j = 0; j < travel_path.size(); j++) {
        // Positive instance: copy of the record with the item replaced by
        // this layer's path node (label feasign left untouched).
        paddle::framework::Record instance(data);
        instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
            travel_path[j].id();
        sample_results->push_back(instance);
        // Negative instances for this layer.
        for (int idx_offset = 0; idx_offset < layer_counts_[j];
             idx_offset++) {
          int sample_res = 0;
          // Resample until the drawn node differs from the positive one.
          do {
            sample_res = sampler_vec_[j]->Sample();
          } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
          paddle::framework::Record instance(data);
          instance.uint64_feasigns_[sample_feasign_idx]
              .sign()
              .uint64_feasign_ = layer_ids_[j][sample_res].id();
          VLOG(1) << "layer id :" << layer_ids_[j][sample_res].id();
          // sample_feasign_idx + 1 == label's id
          instance.uint64_feasigns_[sample_feasign_idx + 1]
              .sign()
              .uint64_feasign_ = 0;
          sample_results->push_back(instance);
        }
        VLOG(1) << "layer end!!!!!!!!!!!!!!!!!!";
      }
    }
  }
  VLOG(1) << "after sample, sample_results.size = "
          << sample_results->size();
  return;
}
// Converts each double in `tmp` to uint64_t by truncation toward zero
// (e.g. 2.9 -> 2). Negative inputs and values beyond the uint64_t range are
// undefined behavior for this conversion, as in the original.
//
// FIXED: reserve the output up front (the loop previously reallocated as it
// grew) and use an explicit static_cast instead of the functional-style
// cast.
std::vector<uint64_t> float2int(std::vector<double> tmp) {
  std::vector<uint64_t> tmp_int;
  tmp_int.reserve(tmp.size());
  for (auto i : tmp) {
    tmp_int.push_back(static_cast<uint64_t>(i));
  }
  return tmp_int;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_sampler.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// Abstract interface for index-based samplers (layer-wise, beam-search).
// Concrete samplers are created through the static Init<T> factory by name.
class IndexSampler {
 public:
  virtual ~IndexSampler() {}
  IndexSampler() {}

  // Factory: constructs a T(name) and returns it as the base interface.
  // FIXED: was `shared_ptr instance = nullptr; instance.reset(new T(name))`
  // — std::make_shared is exception-safe and does a single allocation for
  // object and control block.
  template <typename T>
  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
    return std::make_shared<T>(name);
  }

  // Configures layer-wise sampling; no-op in the base class.
  // NOTE(review): default arguments on a virtual function are bound
  // statically; overrides must not re-declare different defaults.
  virtual void init_layerwise_conf(
      const std::vector<uint16_t>& layer_sample_counts,
      uint16_t start_sample_layer = 1,
      uint16_t seed = 0) {}

  // Configures beam-search sampling; no-op in the base class.
  virtual void init_beamsearch_conf(const int64_t k) {}

  // Draws samples for explicit (user_inputs, targets) batches.
  virtual std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& input_targets,
      bool with_hierarchy = false) = 0;

  // Expands dataset Records in place into sampled Records.
  virtual void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) = 0;
};
// Layer-wise sampler over a TreeIndex: for each layer from
// start_sample_layer_ up to the tree height, draws uniform negative samples
// among that layer's nodes. The tree is looked up by name from the global
// IndexWrapper at construction time.
class LayerWiseSampler : public IndexSampler {
 public:
  virtual ~LayerWiseSampler() {}
  explicit LayerWiseSampler(const std::string& name) {
    tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
  }

  // Validates the start layer, records per-layer negative-sample counts
  // (layer_counts_), and builds one UniformSampler per sampled layer.
  void init_layerwise_conf(const std::vector<uint16_t>& layer_sample_counts,
                           uint16_t start_sample_layer,
                           uint16_t seed) override {
    seed_ = seed;
    start_sample_layer_ = start_sample_layer;
    PADDLE_ENFORCE_GT(
        start_sample_layer_,
        0,
        paddle::platform::errors::InvalidArgument(
            "start sampler layer = [%d], it should greater than 0.",
            start_sample_layer_));
    PADDLE_ENFORCE_LT(start_sample_layer_,
                      tree_->Height(),
                      paddle::platform::errors::InvalidArgument(
                          "start sampler layer = [%d], it should less than "
                          "max_layer, which is [%d].",
                          start_sample_layer_,
                          tree_->Height()));
    // Walk layers from start_sample_layer_ to the last; layers beyond the
    // provided counts default to 1 negative sample.
    size_t i = 0;
    layer_counts_sum_ = 0;
    layer_counts_.clear();
    int cur_layer = start_sample_layer_;
    while (cur_layer < tree_->Height()) {
      int layer_sample_num = 1;
      if (i < layer_sample_counts.size()) {
        layer_sample_num = layer_sample_counts[i];
      }
      // +1 accounts for the positive example emitted alongside the
      // negatives at each layer.
      layer_counts_sum_ += layer_sample_num + 1;
      layer_counts_.push_back(layer_sample_num);
      VLOG(3) << "[INFO] level " << cur_layer
              << " sample_layer_counts.push_back: " << layer_sample_num;
      cur_layer += 1;
      i += 1;
    }
    // Reverse so layer_counts_[0] corresponds to the deepest sampled
    // layer, matching the order of sampler_vec_/layer_ids_ built below.
    reverse(layer_counts_.begin(), layer_counts_.end());
    VLOG(3) << "sample counts sum: " << layer_counts_sum_;
    // Build per-layer node lists and uniform samplers, deepest layer
    // first.
    auto max_layer = tree_->Height();
    sampler_vec_.clear();
    layer_ids_.clear();
    auto layer_index = max_layer - 1;
    size_t idx = 0;
    while (layer_index >= start_sample_layer_) {
      auto layer_codes = tree_->GetLayerCodes(layer_index);
      layer_ids_.push_back(tree_->GetNodes(layer_codes));
      // UniformSampler draws in [0, n]; pass size-1 so indices stay in
      // range of layer_ids_[idx].
      auto sampler_temp =
          std::make_shared<paddle::operators::math::UniformSampler>(
              layer_ids_[idx].size() - 1, seed_);
      sampler_vec_.push_back(sampler_temp);
      layer_index--;
      idx++;
    }
  }

  // Implemented in index_sampler.cc.
  std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& target_ids,
      bool with_hierarchy) override;

  // Implemented in index_sampler.cc.
  void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) override;

 private:
  std::vector<int> layer_counts_;     // negatives per sampled layer
  int64_t layer_counts_sum_{0};       // rows emitted per input in sample()
  std::shared_ptr<TreeIndex> tree_{nullptr};
  int seed_{0};
  int start_sample_layer_{1};
  // One uniform sampler per layer, deepest layer first.
  std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
  // Nodes of each sampled layer, same order as sampler_vec_.
  std::vector<std::vector<IndexNode>> layer_ids_;
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.cc
0 → 100644
View file @
de2e6515
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
distributed
{
std
::
shared_ptr
<
IndexWrapper
>
IndexWrapper
::
s_instance_
(
nullptr
);
// Loads a serialized tree index from `filename`. The file is a sequence of
// [int32 length][KVItem bytes] records: the ".tree_meta" key carries the
// TreeMeta; every other key is a node code (decimal string) whose value is
// an IndexNode. Populates data_, id_codes_map_ (leaves only), max_id_,
// max_code_ (one past the largest code seen) and total_nodes_num_.
// Returns 0 on success; enforcement failures throw via PADDLE_ENFORCE.
int TreeIndex::Load(const std::string filename) {
  int err_no;
  auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
  PADDLE_ENFORCE_NE(
      fp,
      nullptr,
      platform::errors::InvalidArgument(
          "Open file %s failed. Please check whether the file exists.",
          filename));
  int num = 0;
  // Reset state so Load can be called on a fresh or reused index.
  max_id_ = 0;
  // fake_node_ is the placeholder returned by GetNodes for invalid codes.
  fake_node_.set_id(0);
  fake_node_.set_is_leaf(false);
  fake_node_.set_probability(0.0);
  max_code_ = 0;
  // Read the first record's length prefix; loop ends on EOF or a
  // non-positive length.
  size_t ret = fread(&num, sizeof(num), 1, fp.get());
  while (ret == 1 && num > 0) {
    std::string content(num, '\0');
    size_t read_num =
        fread(const_cast<char*>(content.data()), 1, num, fp.get());
    PADDLE_ENFORCE_EQ(
        read_num,
        static_cast<size_t>(num),
        platform::errors::InvalidArgument(
            "Read from file: %s failed. Valid Format is "
            "an integer representing the length of the following string, "
            "and the string itself.We got an iteger[% d], "
            "but the following string's length is [%d].",
            filename,
            num,
            read_num));
    KVItem item;
    PADDLE_ENFORCE_EQ(
        item.ParseFromString(content),
        true,
        platform::errors::InvalidArgument("Parse from file: %s failed. It's "
                                          "content can't be parsed by KVItem.",
                                          filename));
    if (item.key() == ".tree_meta") {
      meta_.ParseFromString(item.value());
    } else {
      auto code = std::stoull(item.key());
      IndexNode node;
      node.ParseFromString(item.value());
      // PADDLE_ENFORCE_NE(node.id(), 0,
      //                   platform::errors::InvalidArgument(
      //                       "Node'id should not be equel to zero."));
      // Only leaves get an id -> code reverse mapping.
      if (node.is_leaf()) {
        id_codes_map_[node.id()] = code;
      }
      data_[code] = node;
      if (node.id() > max_id_) {
        max_id_ = node.id();
      }
      if (code > max_code_) {
        max_code_ = code;
      }
    }
    // Next record's length prefix.
    ret = fread(&num, sizeof(num), 1, fp.get());
  }
  total_nodes_num_ = data_.size();
  // max_code_ becomes an out-of-range sentinel (used by GetAncestorCodes
  // for unknown ids).
  max_code_ += 1;
  return 0;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetNodes
(
const
std
::
vector
<
uint64_t
>&
codes
)
{
std
::
vector
<
IndexNode
>
nodes
;
nodes
.
reserve
(
codes
.
size
());
for
(
size_t
i
=
0
;
i
<
codes
.
size
();
i
++
)
{
if
(
CheckIsValid
(
codes
[
i
]))
{
nodes
.
push_back
(
data_
.
at
(
codes
[
i
]));
}
else
{
nodes
.
push_back
(
fake_node_
);
}
}
return
nodes
;
}
// Returns the codes of every node actually present on tree level `level`.
// A complete level holds branch^level slots starting at code
// branch^level - 1; missing slots are filtered out.
std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
  const auto level_width =
      static_cast<uint64_t>(std::pow(meta_.branch(), level));
  const uint64_t first_code = level_width - 1;
  std::vector<uint64_t> codes;
  codes.reserve(level_width);
  for (uint64_t offset = 0; offset < level_width; ++offset) {
    const uint64_t candidate = first_code + offset;
    if (CheckIsValid(candidate)) {
      codes.push_back(candidate);
    }
  }
  return codes;
}
// For each leaf id, walks from its code up to tree level `level` and
// returns that ancestor's code. Ids not present in id_codes_map_ map to
// the sentinel max_code_.
std::vector<uint64_t> TreeIndex::GetAncestorCodes(
    const std::vector<uint64_t>& ids, int level) {
  std::vector<uint64_t> ancestors;
  ancestors.reserve(ids.size());
  for (const auto id : ids) {
    const auto it = id_codes_map_.find(id);
    if (it == id_codes_map_.end()) {
      // Unknown id: report the out-of-range sentinel.
      ancestors.push_back(max_code_);
      continue;
    }
    uint64_t code = it->second;
    int depth = meta_.height() - 1;
    // Climb one level per step: parent(code) = (code - 1) / branch.
    while (level >= 0 && depth > level) {
      code = (code - 1) / meta_.branch();
      --depth;
    }
    ancestors.push_back(code);
  }
  return ancestors;
}
// Collects the codes of all descendants of `ancestor` that lie on tree
// level `level` (i.e. codes in [branch^level - 1, branch * branch^level - 1)),
// using a breadth-first expansion over the codes present in data_.
std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor, int level) {
  // Code range [code_min, code_max) occupied by level `level`.
  auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
  auto code_min = level_code_num - 1;
  auto code_max = meta_.branch() * level_code_num - 1;
  // `parent` doubles as the BFS queue; `p_idx` marks the start of the
  // current frontier within it.
  std::vector<uint64_t> parent;
  parent.push_back(ancestor);
  std::vector<uint64_t> res;  // NOTE(review): never used -- left in place.
  size_t p_idx = 0;
  while (true) {
    size_t p_size = parent.size();
    // Expand the whole current frontier by one level, keeping only child
    // codes that actually exist in the tree.
    for (; p_idx < p_size; p_idx++) {
      for (int i = 0; i < meta_.branch(); i++) {
        auto code = parent[p_idx] * meta_.branch() + i + 1;
        if (data_.find(code) != data_.end()) parent.push_back(code);
      }
    }
    // Stop once the frontier has reached the requested code range.
    // NOTE(review): if an expansion round produces no children, p_idx ==
    // parent.size() here and parent[p_idx] reads out of bounds -- TODO
    // confirm callers only request levels reachable from `ancestor`.
    if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
      break;
    }
  }
  // The tail of `parent` (from p_idx on) is exactly the requested level.
  return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
}
// Returns the chain of codes from leaf `id` up to (and including) tree
// level `start_level`, ordered leaf-first. Enforces that `id` exists.
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
  std::vector<uint64_t> path;
  PADDLE_ENFORCE_NE(id_codes_map_.find(id),
                    id_codes_map_.end(),
                    paddle::platform::errors::InvalidArgument(
                        "id = %d doesn't exist in Tree.", id));
  uint64_t code = id_codes_map_.at(id);
  // Walk upward one level per iteration: parent(code) = (code - 1) / branch.
  for (int depth = meta_.height() - 1; depth >= start_level; --depth) {
    path.push_back(code);
    code = (code - 1) / meta_.branch();
  }
  return path;
}
std
::
vector
<
IndexNode
>
TreeIndex
::
GetAllLeafs
()
{
std
::
vector
<
IndexNode
>
res
;
res
.
reserve
(
id_codes_map_
.
size
());
for
(
auto
&
ite
:
id_codes_map_
)
{
auto
code
=
ite
.
second
;
res
.
push_back
(
data_
.
at
(
code
));
}
return
res
;
}
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/index_dataset/index_wrapper.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// Marker base class for index structures (TreeIndex derives from it).
// The destructor is made virtual: this class is used as a public base, and
// a non-virtual destructor makes deletion through an Index* undefined
// behavior.
class Index {
 public:
  Index() {}
  virtual ~Index() {}
};
// In-memory tree index loaded from a serialized KVItem file (see
// TreeIndex::Load in index_wrapper.cc). Nodes are addressed by "codes":
// per the traversal arithmetic used by its methods, the children of code c
// are codes c*branch + 1 ... c*branch + branch.
class TreeIndex : public Index {
 public:
  TreeIndex() {}
  ~TreeIndex() {}

  // Tree height recorded in the loaded meta.
  int Height() { return meta_.height(); }
  // Branching factor (children per node).
  int Branch() { return meta_.branch(); }
  // Number of nodes loaded into data_.
  uint64_t TotalNodeNums() { return total_nodes_num_; }
  // One past the largest node id; usable as an embedding table size.
  uint64_t EmbSize() { return max_id_ + 1; }
  // Loads the tree from `path`; returns 0 on success.
  int Load(const std::string path);

  // True when `code` refers to a node present in the tree.
  // FIX: the parameter was `int`, silently truncating the uint64_t codes
  // every caller passes; it is widened to uint64_t (backward compatible).
  inline bool CheckIsValid(uint64_t code) {
    return data_.find(code) != data_.end();
  }

  std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
  std::vector<uint64_t> GetLayerCodes(int level);
  std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
                                         int level);
  std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
  std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
  std::vector<IndexNode> GetAllLeafs();

  std::unordered_map<uint64_t, IndexNode> data_;         // code -> node
  std::unordered_map<uint64_t, uint64_t> id_codes_map_;  // leaf id -> code
  // Scalars are zero-initialized so the accessors above are well-defined
  // even before Load() has run.
  uint64_t total_nodes_num_ = 0;
  TreeMeta meta_;
  uint64_t max_id_ = 0;
  uint64_t max_code_ = 0;
  IndexNode fake_node_;  // placeholder returned by GetNodes for bad codes
};
using
TreePtr
=
std
::
shared_ptr
<
TreeIndex
>
;
// Process-wide registry of named TreeIndex instances (lazy singleton).
// NOTE(review): singleton creation below is not synchronized; it appears to
// assume single-threaded initialization -- confirm before concurrent use.
class IndexWrapper {
 public:
  virtual ~IndexWrapper() {}
  IndexWrapper() {}

  // Drops every registered tree.
  void clear_tree() { tree_map.clear(); }

  // Returns the tree registered under `name`; enforces that it exists.
  TreePtr get_tree_index(const std::string name) {
    PADDLE_ENFORCE_NE(tree_map.find(name),
                      tree_map.end(),
                      paddle::platform::errors::InvalidArgument(
                          "tree [%s] doesn't exist. Please insert it firstly "
                          "by API[\'insert_tree_index\'].",
                          name));
    return tree_map[name];
  }

  // Loads a tree from `tree_path` and registers it under `name`.
  // Re-inserting an existing name is a logged no-op; a failed load enforces.
  void insert_tree_index(const std::string name, const std::string tree_path) {
    if (tree_map.find(name) != tree_map.end()) {
      VLOG(0) << "Tree " << name << " has already existed.";
      return;
    }
    TreePtr tree = std::make_shared<TreeIndex>();
    int ret = tree->Load(tree_path);
    PADDLE_ENFORCE_EQ(ret,
                      0,
                      paddle::platform::errors::InvalidArgument(
                          "Load tree[%s] from path[%s] failed. Please "
                          "check whether the file exists.",
                          name,
                          tree_path));
    tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
  }

  // Lazily creates and returns the shared singleton.
  static std::shared_ptr<IndexWrapper> GetInstancePtr() {
    if (s_instance_ == nullptr) {  // nullptr instead of NULL (idiomatic)
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_;
  }

  // Lazily creates the singleton and returns a raw, non-owning pointer.
  static IndexWrapper* GetInstance() {
    if (s_instance_ == nullptr) {
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_.get();
  }

 private:
  static std::shared_ptr<IndexWrapper> s_instance_;
  std::unordered_map<std::string, TreePtr> tree_map;  // name -> tree
};
}
// end namespace distributed
}
// end namespace paddle
paddle/fluid/distributed/ps/CMakeLists.txt
0 → 100644
View file @
de2e6515
# Publish the RPC dependency set (brpc-generated stubs plus helpers) as a
# global property so sibling directories can pull it in with
# get_property(... GLOBAL PROPERTY RPC_DEPS).
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
# Parameter-server sub-modules: parameter tables, RPC services, wrappers.
add_subdirectory(table)
add_subdirectory(service)
add_subdirectory(wrapper)
paddle/fluid/distributed/ps/README.md
0 → 100644
View file @
de2e6515
# 目录说明
Table: for param storage and update
-----MemorySparseTable: table for sparse param, used in cpu async mode
-----MemoryDenseTable: table for dense param, used in cpu async/geo mode
-----MemorySparseGeoTable: table for sparse param, used in cpu async mode
-----CommonGraphTable: table used for graph learning
-----BarrierTable: table for barrier function, used in cpu sync mode
-----TensorTable: table which runs a program, used for learning rate decay only
ValueAccessor: for pull param and push gradient
-----CtrCommonAccessor: pull/push value with show/click, float type
-----CtrDoubleAccessor: same as CtrCommonAccessor, other than show/click with double type
-----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click
-----CommMergeAccessor: used for dense table only, for get param dim
PsService(proto): for server to handle request
-----PsBaseService
----------BrpcPsService: for cpu dnn training task
----------GraphBrpcService: for graph learning
-----HeterService: for dnn training task with heterogeneous computing resources
PSServer: recv request from trainer and handle it by service
-----BrpcPsServer: for cpu dnn training task
-----GraphBrpcServer: for graph learning
-----PsLocalServer: for GpuPS
HeterServer: for HeterPS
PSClient: pull param and push gradient for trainer
-----BrpcPsClient: for cpu dnn training task
----------GraphBrpcClient: for graph learning
-----PsLocalClient: for GpuPS
HeterClient: for HeterPS
PSCore: Wrapper for InitServer
GraphPyService: for graph learning
paddle/fluid/distributed/ps/service/CMakeLists.txt
0 → 100644
View file @
de2e6515
# Sources compiled into the brpc-generated RPC library.
set(BRPC_SRCS ps_client.cc server.cc)
# NOTE(review): no properties are supplied, so this call is a no-op.
set_source_files_properties(${BRPC_SRCS})

# Link dependencies for brpc; the HeterPS build additionally needs rocksdb.
# NOTE(review): gflags and glog each appear twice in both lists.
if(WITH_HETERPS)
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context
      rocksdb)
else()
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context)
endif()

# Generate stubs from sendrecv.proto and build them with BRPC_SRCS.
brpc_library(
  sendrecv_rpc
  SRCS
  ${BRPC_SRCS}
  PROTO
  sendrecv.proto
  DEPS
  ${BRPC_DEPS})

#set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
# Read the RPC dependency set published by the parent directory.
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

# Apply the distributed-mode compile flags to every RPC-facing source file.
set_source_files_properties(
  communicator/communicator.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_service/service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})

# Helper conversions between brpc messages and framework tensors/scopes.
cc_library(
  brpc_utils
  SRCS brpc_utils.cc
  DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})

# The parameter-server service library: servers, clients and communicator.
cc_library(
  ps_service
  SRCS graph_brpc_server.cc
       brpc_ps_server.cc
       server.cc
       graph_brpc_client.cc
       brpc_ps_client.cc
       ps_local_client.cc
       coordinator_client.cc
       ps_client.cc
       communicator/communicator.cc
       ps_service/service.cc
       ps_service/graph_py_service.cc
  DEPS eigen3
       table
       brpc_utils
       simple_threadpool
       scope
       math_function
       selected_rows_functor
       ${RPC_DEPS})

# Heterogeneous-PS client and server libraries.
cc_library(
  heter_client
  SRCS heter_client.cc
  DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(
  heter_server
  SRCS heter_server.cc
  DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
paddle/fluid/distributed/ps/service/README.md
0 → 100644
View file @
de2e6515
# 目录说明
*
PSServer
*
PSClient
*
PsService
*
Communicator
*
MessageBusFramework
* *.proto
paddle/fluid/distributed/ps/service/brpc_ps_client.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
static
const
int
max_port
=
65535
;
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
distributed
{
// Tunable gflags for the brpc PS client: local merge limits, async push
// intervals, RPC timeouts, compression and shard layout.
DEFINE_int32(pserver_push_dense_merge_limit,
             12,
             "limit max push_dense local merge requests");
DEFINE_int32(pserver_push_sparse_merge_limit,
             12,
             "limit max push_sparse local merge requests");
// FIX: description previously said "push_sparse" (copy-paste error).
DEFINE_int32(pserver_pull_dense_limit,
             12,
             "limit max pull_dense local merge requests");
DEFINE_int32(pserver_async_push_dense_interval_ms,
             10,
             "async push_dense to server interval");
DEFINE_int32(pserver_async_push_sparse_interval_ms,
             10,
             "async push_sparse to server interval");
DEFINE_bool(pserver_scale_gradient_by_merge,
            false,
            "scale dense gradient when merged");
DEFINE_int32(pserver_communicate_compress_type,
             0,
             "none:0 snappy:1 gzip:2 zlib:3 lz4:4");
DEFINE_int32(pserver_max_async_call_num,
             13,
             "max task num in async_call_server");
DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms,
             10000,
             "pserver connect server timeout_ms");
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num,
             1000,
             "sparse table shard for save & load");
// Maps a sparse key to the index of the server that owns its shard.
// Shards are distributed contiguously: each server owns
// ceil(shard_num / server_num) consecutive shards.
inline size_t get_sparse_shard(uint32_t shard_num,
                               uint32_t server_num,
                               uint64_t key) {
  // Shards per server, rounded up when shards do not divide evenly.
  size_t shards_per_server = shard_num / server_num;
  if (shard_num % server_num != 0) {
    ++shards_per_server;
  }
  // First map the key onto a shard, then the shard onto its server.
  return (key % shard_num) / shards_per_server;
}
// Handles an incoming client-to-client RPC: forwards the message to the
// owning client and reports the outcome through the response error fields.
void DownpourPsClientService::service(
    ::google::protobuf::RpcController* controller,
    const PsRequestMessage* request,
    PsResponseMessage* response,
    ::google::protobuf::Closure* done) {
  // Guarantees `done` runs exactly once on every path out of this function.
  brpc::ClosureGuard done_guard(done);
  const int status = _client->HandleClient2ClientMsg(
      request->cmd_id(), request->client_id(), request->data());
  if (status != 0) {
    response->set_err_code(-1);
    response->set_err_msg("handle_client2client_msg failed");
  } else {
    response->set_err_code(0);
    response->set_err_msg("");
  }
}
// Starts the client-side RPC service (used for client-to-client data
// exchange and similar operations) and registers this client's endpoint
// with the environment.
int32_t BrpcPsClient::StartClientService() {
  if (_service.Configure(this, _client_id) != 0) {
    LOG(ERROR)
        << "service initialize failed, service_name:DownpourPsClientService";
    return -1;
  }
  // The server must not delete _service: it is owned by this object.
  _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  int start_port = 8500;
  options.num_threads = 24;
  // Try ports in [start_port, max_port] on the local IP.
  if (_server.Start(butil::my_ip_cstr(),
                    brpc::PortRange(start_port, max_port),
                    &options) != 0) {
    LOG(ERROR) << "BrpcPsServer start failed";
    return -1;
  }
  _server_started = true;
  // Publish the chosen endpoint so peers can connect back to this client.
  _env->RegistePsClient(
      butil::my_ip_cstr(), _server.listen_address().port, _client_id);
  VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", "
          << _server.listen_address().port << ", " << _client_id;
  return 0;
}
// Starts the FL client service that receives data from the coordinator,
// listening on `self_endpoint`. If binding the literal endpoint fails, it
// retries with an integer-typed ip:port form. Returns 0 on success.
int32_t BrpcPsClient::StartFlClientService(const std::string& self_endpoint) {
  _fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (self_endpoint.empty()) {
    LOG(ERROR) << "fl-ps > fl client endpoint not set";
    return -1;
  }
  if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
    VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
    // Fall back to the numeric form of the same endpoint.
    auto ip_port = paddle::string::Split(self_endpoint, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= "
                 << int_ip_port;
      return -1;
    }
  } else {
    VLOG(0) << "fl-ps > StartFlClientService succeed! listen on "
            << self_endpoint;
  }
  return 0;
}
// Opens a pooled baidu_std channel to every peer client registered in the
// environment, storing them in _client_channels. Each connection falls
// back to the integer-typed endpoint form on first failure. Returns 0 on
// success, -1 if any peer cannot be reached.
int32_t BrpcPsClient::CreateClient2ClientConnection(
    int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = pserver_connect_timeout_ms;
  options.max_retry = max_retry;
  std::vector<PSHost> client_list = _env->GetPsClients();
  VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
          << client_list.size();
  for (auto cc : client_list) {
    VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
            << cc.ToString();
  }
  _client_channels.resize(client_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < client_list.size(); ++i) {
    server_ip_port.assign(client_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(client_list[i].port));
    _client_channels[i].reset(new brpc::Channel());
    // brpc Channel::Init returns non-zero on failure.
    if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options)) {
      VLOG(0) << "BrpcPSClient connect to Client:" << server_ip_port
              << " Failed! Try again.";
      // Retry with the numeric endpoint form.
      std::string int_ip_port =
          GetIntTypeEndpoint(client_list[i].ip, client_list[i].port);
      if (_client_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "BrpcPSClient connect to Client:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "Client connect success:" << os.str();
  return 0;
}
// Initializes this client as a federated-learning worker: connects every
// channel slot to each coordinator (with a numeric-endpoint fallback) and
// then starts the local FL client service on `self_endpoint`.
// Returns 0 on success, -1 if a coordinator cannot be reached.
int32_t BrpcPsClient::InitializeFlWorker(const std::string& self_endpoint) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms =
      paddle::distributed::FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  // Fetch the coordinator list and connect to each one.
  std::string coordinator_ip_port;
  std::vector<PSHost> coordinator_list = _env->GetCoordinators();
  _coordinator_channels.resize(coordinator_list.size());
  for (size_t i = 0; i < coordinator_list.size(); ++i) {
    coordinator_ip_port.assign(coordinator_list[i].ip.c_str());
    coordinator_ip_port.append(":");
    coordinator_ip_port.append(std::to_string(coordinator_list[i].port));
    VLOG(0) << "fl-ps > BrpcFlclient connetcting to coordinator: "
            << coordinator_ip_port;
    for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) {
      _coordinator_channels[i][j].reset(new brpc::Channel());
      if (_coordinator_channels[i][j]->Init(
              coordinator_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
                   << coordinator_ip_port << " Failed! Try again.";
        // Retry with the numeric endpoint form.
        std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip,
                                                     coordinator_list[i].port);
        if (_coordinator_channels[i][j]->Init(
                int_ip_port.c_str(), "", &options) != 0) {
          LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
                     << int_ip_port << " Failed!";
          return -1;
        }
      }
    }
  }
  // NOTE(review): the return value of StartFlClientService is ignored here.
  StartFlClientService(self_endpoint);
  VLOG(0) << "fl-ps > InitializeFlWorker finished!";
  return 0;
}
// Synchronously pushes this FL client's info string to the coordinator(s):
// issues one FLService RPC per coordinator channel and waits for completion
// via a promise set by the closure callback.
// NOTE(review): on a failed response the callback returns *before* calling
// set_promise_value, so fut.wait() below would block forever -- TODO confirm.
// NOTE(review): rpc_channel is always _coordinator_channels[0][0] even
// though the loop iterates i over request_call_num -- looks like only one
// coordinator is actually contacted; verify intent.
void BrpcPsClient::PushFLClientInfoSync(const std::string& fl_client_info) {
  size_t request_call_num = _coordinator_channels.size();
  FlClientBrpcClosure* closure = new FlClientBrpcClosure(
      request_call_num, [request_call_num](void* done) {
        auto* closure = reinterpret_cast<FlClientBrpcClosure*>(done);
        int ret = 0;
        for (size_t i = 0; i < request_call_num; i++) {
          if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) {
            LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from "
                          "coordinator is failed";
            ret = -1;
            return;
          } else {
            VLOG(0) << "fl-ps > rpc service call cost time: "
                    << (closure->cntl(i)->latency_us() / 1000) << " ms";
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
  closure->add_promise(promise);
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC);
    closure->request(i)->set_client_id(_client_id);
    closure->request(i)->set_str_params(fl_client_info);
    brpc::Channel* rpc_channel = _coordinator_channels[0][0].get();
    if (rpc_channel == nullptr) {
      LOG(ERROR) << "_coordinator_channels is null";
      return;
    }
    PsService_Stub rpc_stub(rpc_channel);
    // CoordinatorService
    rpc_stub.FLService(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
    // Block until the closure fires before issuing the next request.
    fut.wait();
  }
  VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: "
          << _client_id;
  return;
}
// Blocks until the coordinator has delivered an FL strategy to this
// client's service, then consumes and returns it.
// NOTE(review): busy-waits with a 1 s sleep, and resets the ready flag
// without synchronization; the in-code comment restricts this to
// single-threaded use.
std::string BrpcPsClient::PullFlStrategy() {
  while (!_service._is_fl_strategy_ready) {
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    VLOG(0) << "fl-ps > waiting for fl strategy returned from coordinator";
  }
  _service._is_fl_strategy_ready = false;
  // only support single thread, no need for multi-threads
  return _service._fl_strategy;
}
// Full client initialization: connects to every PS server, starts the
// client-to-client RPC service, builds the per-table async push queues,
// registers cost profilers, and launches the async push consumer threads.
// Returns 0 on success, -1 if any server connection fails.
int32_t BrpcPsClient::Initialize() {
  _async_call_num = 0;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  std::ostringstream os;
  std::string server_ip_port;
  std::string client_ip(butil::my_ip_cstr());
  // Fetch the server list and connect to each server on every channel slot.
  std::vector<PSHost> server_list = _env->GetPsServers();
  _server_channels.resize(server_list.size());
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    for (size_t j = 0; j < _server_channels[i].size(); ++j) {
      _server_channels[i][j].reset(new brpc::Channel());
      if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        VLOG(0) << "BrpcPSclient connect to Server:" << server_ip_port
                << " Failed! Try again.";
        // Retry with the numeric endpoint form.
        std::string int_ip_port =
            GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
        if (_server_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "BrpcPSclient connect to Server:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
    os << server_ip_port << ",";
  }
  // Start the client listening service and build client connections.
  StartClientService();
  // Initialize the async push request queues, one per configured table.
  const auto& worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    auto type = worker_param.downpour_table_param(i).type();
    auto table_id = worker_param.downpour_table_param(i).table_id();
    if (type == PS_DENSE_TABLE) {
      _push_dense_task_queue_map[table_id] =
          paddle::framework::MakeChannel<DenseAsyncTask*>();
    }
    if (type == PS_SPARSE_TABLE) {
      _push_sparse_task_queue_map[table_id] =
          paddle::framework::MakeChannel<SparseAsyncTask*>();
      _push_sparse_merge_count_map[table_id] = 0;
    }
  }
  // Register cost profilers for the pull/push code paths.
  // NOTE(review): "pserver_client_push_sparse" is registered twice below.
  auto& profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_client_pull_dense");
  profiler.register_profiler("pserver_client_pull_sparse");
  profiler.register_profiler("pserver_client_pull_sparse_param");
  profiler.register_profiler("pserver_client_pull_sparse_local");
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_parse");
  profiler.register_profiler("client_push_sparse_put");
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_merge");
  profiler.register_profiler("pserver_client_push_sparse_rpc");
  profiler.register_profiler("pserver_client_push_dense");
  profiler.register_profiler("pserver_client_push_dense_parse");
  profiler.register_profiler("push_dense_put");
  profiler.register_profiler("pserver_client_push_dense_merge");
  profiler.register_profiler("pserver_client_push_dense_rpc");
  profiler.register_profiler("pserver_client_push_dense_send");
  _running = true;
  _flushing = false;
  // Launch the async push consumer threads.
  _async_push_sparse_thread =
      std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
  // _async_push_sparse_thread.detach();
  _async_push_dense_thread =
      std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
  // for debug
  // _print_thread =
  //     std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
  return 0;
}
int
DownpourBrpcClosure
::
check_response
(
size_t
request_idx
,
int
cmd_id
)
{
if
(
_cntls
[
request_idx
]
->
Failed
())
{
LOG
(
ERROR
)
<<
"resquest cmd_id:"
<<
cmd_id
<<
" failed, "
"err:"
<<
_cntls
[
request_idx
]
->
ErrorText
();
return
-
1
;
}
if
(
_responses
[
request_idx
].
err_code
()
!=
0
)
{
LOG
(
ERROR
)
<<
"response ret bad, server_idx:"
<<
request_idx
<<
"cmd_id:"
<<
cmd_id
<<
" err_code:"
<<
_responses
[
request_idx
].
err_code
()
<<
" err_msg:"
<<
_responses
[
request_idx
].
err_msg
();
return
-
1
;
}
return
0
;
}
int
DownpourBrpcClosure
::
check_save_response
(
size_t
request_idx
,
int
cmd_id
)
{
int32_t
feasign_size
=
0
;
if
(
_cntls
[
request_idx
]
->
Failed
())
{
LOG
(
ERROR
)
<<
"resquest cmd_id:"
<<
cmd_id
<<
" failed, "
"err:"
<<
_cntls
[
request_idx
]
->
ErrorText
();
return
-
1
;
}
feasign_size
=
_responses
[
request_idx
].
err_code
();
if
(
feasign_size
<
0
)
{
LOG
(
ERROR
)
<<
"response ret bad, server_idx:"
<<
request_idx
<<
"cmd_id:"
<<
cmd_id
<<
" err_code:"
<<
_responses
[
request_idx
].
err_code
()
<<
" err_msg:"
<<
_responses
[
request_idx
].
err_msg
();
return
-
1
;
}
return
feasign_size
;
}
// Returns a copy of the payload carried by the `request_idx`-th response.
// `cmd_id` is accepted for signature symmetry with check_response but is
// not used here.
std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
  return _responses[request_idx].data();
}
int
FlClientBrpcClosure
::
check_response
(
size_t
request_idx
,
int
cmd_id
)
{
if
(
_cntls
[
request_idx
]
->
Failed
())
{
LOG
(
ERROR
)
<<
"resquest cmd_id:"
<<
cmd_id
<<
" failed, "
"err:"
<<
_cntls
[
request_idx
]
->
ErrorText
();
return
-
1
;
}
if
(
_responses
[
request_idx
].
err_code
()
!=
0
)
{
LOG
(
ERROR
)
<<
"response ret bad, server_idx:"
<<
request_idx
<<
"cmd_id:"
<<
cmd_id
<<
" err_code:"
<<
_responses
[
request_idx
].
err_code
()
<<
" err_msg:"
<<
_responses
[
request_idx
].
err_msg
();
return
-
1
;
}
return
0
;
}
// Broadcasts PS_PRINT_TABLE_STAT for `table_id` to every server, sums the
// feasign/mf sizes from the responses and prints them to stdout. The
// returned future resolves to 0 on success, -1 on any failed response.
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, table_id](void* done) {
        int ret = 0;
        uint64_t feasign_size = 0;
        uint64_t mf_size = 0;
        paddle::framework::BinaryArchive ar;
        auto* closure = reinterpret_cast<DownpourBrpcClosure*>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
            ret = -1;
            break;
          }
          // Each response serializes two uint64 values back-to-back.
          std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
          ar.SetReadBuffer(
              const_cast<char*>(resp.c_str()), resp.length(), nullptr);
          feasign_size += ar.Get<uint64_t>();
          mf_size += ar.Get<uint64_t>();
        }
        closure->set_promise_value(ret);
        std::cout << "table id: " << table_id
                  << ", feasign size: " << feasign_size
                  << ", mf size: " << mf_size << std::endl;
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // cmd msg don't limit timeout for save/load
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a generic command (`cmd_id` with string `params`) for
// `table_id` to every server. The returned future resolves to 0 when all
// responses are OK, -1 otherwise.
std::future<int32_t> BrpcPsClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void* done) {
        int ret = 0;
        auto* closure = reinterpret_cast<DownpourBrpcClosure*>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto& param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // cmd msg don't limit timeout for save/load
    closure->cntl(i)->set_timeout_ms(10800000 * 2);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a save-style command (`cmd_id` with string `params`) for
// `table_id` to every server. Save responses overload err_code as the saved
// feasign count, so the returned future resolves to the total feasign size
// on success and -1 on any failure.
// FIX: check_save_response was called twice per response (once for the
// error check, once to accumulate); it is now called once and the result
// reused.
std::future<int32_t> BrpcPsClient::SendSaveCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure* closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void* done) {
        int ret = 0;
        uint32_t feasign_size = 0;
        auto* closure = reinterpret_cast<DownpourBrpcClosure*>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          // Negative value signals failure; non-negative is a feasign count.
          int size_or_err = closure->check_save_response(i, cmd_id);
          if (size_or_err < 0) {
            ret = -1;
            break;
          }
          feasign_size += size_or_err;
        }
        if (ret == 0) {
          closure->set_promise_value(feasign_size);
        } else {
          closure->set_promise_value(ret);
        }
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto& param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // cmd msg don't limit timeout for save/load
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Ask every pserver to shrink the given table, dropping entries whose score
// falls below `threshold` (forwarded as a string parameter).
std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
                                          const std::string threshold) {
  std::vector<std::string> args;
  args.push_back(threshold);
  return SendCmd(table_id, PS_SHRINK_TABLE, args);
}
// Load ALL tables from checkpoint `epoch` using the given load mode.
// table_id -1 (wraps to the broadcast sentinel) addresses every table.
std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
                                        const std::string &mode) {
  const std::vector<std::string> args = {epoch, mode};
  return SendCmd(-1, PS_LOAD_ALL_TABLE, args);
}
// Load a single table from checkpoint `epoch` using the given load mode.
std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  const std::vector<std::string> args = {epoch, mode};
  return SendCmd(table_id, PS_LOAD_ONE_TABLE, args);
}
// Save ALL tables to path `epoch` with the given save mode. The returned
// future resolves to the total saved feasign count, or -1 on failure.
std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save path " << epoch;
  const std::vector<std::string> args = {epoch, mode};
  return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, args);
}
// Save one table to path `epoch` with the given save mode. The returned
// future resolves to the saved feasign count for that table, or -1 on failure.
std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
          << table_id;
  const std::vector<std::string> args = {epoch, mode};
  return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, args);
}
// Trigger a cache shuffle on one table: entries above `cache_threshold`
// are shuffled into the cache files under `path`.
std::future<int32_t> BrpcPsClient::CacheShuffle(
    uint32_t table_id,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
  const std::vector<std::string> args = {path, mode, cache_threshold};
  return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, args);
}
// Cache-shuffle several tables into one output path. The parameter list sent
// to the servers is: path, mode, cache_threshold, then one stringified table
// id per entry of `tables`.
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
    std::vector<int> tables,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
  std::vector<std::string> param;
  param.reserve(tables.size() + 3);
  param.push_back(path);
  param.push_back(mode);
  param.push_back(cache_threshold);
  for (const int table : tables) {
    param.push_back(std::to_string(table));
  }
  return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}
// Persist the cached portion of one table to `path` with the given mode.
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
                                             const std::string &path,
                                             const std::string &mode) {
  const std::vector<std::string> args = {path, mode};
  return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, args);
}
// Query every pserver for its cache threshold and write the average of the
// non-negative replies into `cache_threshold`.
//
// NOTE(review): `cache_threshold` is captured BY REFERENCE inside the async
// RPC callback — the caller must keep the referenced double alive until the
// returned future resolves, or the callback writes to a dangling reference.
// NOTE(review): std::stod will throw on a malformed server reply; presumably
// replies are always numeric — TODO confirm against the server side.
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
                                                     double &cache_threshold) {
  int cmd_id = PS_GET_CACHE_THRESHOLD;
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, cmd_id, &cache_threshold](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        // One threshold slot per server; untouched slots stay 0 if we break
        // out early on error.
        std::vector<double> cache_thresholds(request_call_num, 0);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
          std::string cur_res = closure->get_response(i, cmd_id);
          cache_thresholds[i] = std::stod(cur_res);
        }
        // Average only the non-negative thresholds.
        double sum_threshold = 0.0;
        int count = 0;
        for (auto t : cache_thresholds) {
          if (t >= 0) {
            sum_threshold += t;
            ++count;
          }
        }
        if (count == 0) {
          cache_threshold = 0;
        } else {
          cache_threshold = sum_threshold / count;
        }
        VLOG(1) << "client get cache threshold: " << cache_threshold;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // Very long timeout (3h): command may piggyback on heavy server work.
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Clear every table on every pserver (table_id -1 is the broadcast sentinel).
std::future<int32_t> BrpcPsClient::Clear() {
  return SendCmd(-1, PS_CLEAR_ALL_TABLE, std::vector<std::string>());
}
// Clear a single table on every pserver.
std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
  return SendCmd(table_id, PS_CLEAR_ONE_TABLE, std::vector<std::string>());
}
// Ask every pserver to revert its tables (broadcast, no parameters).
std::future<int32_t> BrpcPsClient::Revert() {
  return SendCmd(-1, PS_REVERT, std::vector<std::string>());
}
// Poll every pserver for completion of the pre-patch save (broadcast).
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
  return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, std::vector<std::string>());
}
// Block until every in-flight async push has drained (_async_call_num == 0),
// then report queue sizes. The returned future is already resolved (value 0)
// by the time this function returns — flushing itself is synchronous.
std::future<int32_t> BrpcPsClient::Flush() {
  VLOG(0) << "BrpcPsClient::flush begin";
  _flushing = true;
  std::promise<int> promise;
  std::future<int32_t> fut = promise.get_future();
  // Poll at 100ms granularity; always sleeps at least once, matching the
  // original do/while shape.
  for (;;) {
    VLOG(3) << "wait _async_call_num:" << _async_call_num;
    usleep(100000);  // sleep 100ms wait async end
    if (_async_call_num <= 0) {
      break;
    }
  }
  VLOG(1) << "flush _async_call_num = 0";
  promise.set_value(0);
  _flushing = false;
  VLOG(0) << "BrpcPsClient::flush done";
  PrintQueueSize();
  return fut;
}
// Log the current depth of every async sparse-push and dense-push task queue,
// one VLOG(0) line per table.
void BrpcPsClient::PrintQueueSize() {
  for (auto &sparse_entry : _push_sparse_task_queue_map) {
    auto sparse_table_id = sparse_entry.first;
    auto sparse_queue_size = sparse_entry.second->Size();
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << sparse_table_id
            << " size: " << sparse_queue_size;
  }
  for (auto &dense_entry : _push_dense_task_queue_map) {
    auto dense_table_id = dense_entry.first;
    auto dense_queue_size = dense_entry.second->Size();
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << dense_table_id
            << " size: " << dense_queue_size;
  }
}
// Background loop: while the client is running, dump queue depths every two
// minutes. Intended to run on its own thread.
void BrpcPsClient::PrintQueueSizeThread() {
  // 2 minutes, expressed in microseconds for usleep.
  const int interval_us = 1000000 * 60 * 2;
  while (_running) {
    usleep(interval_us);
    PrintQueueSize();
  }
}
// Shut the worker down in a safe order: first drain all async pushes
// (Flush), then stop the push threads, then stop and join the local brpc
// server. Reordering these steps could join threads that still have queued
// work or stop the server while RPCs are in flight.
void BrpcPsClient::FinalizeWorker() {
  Flush();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
  // _running = false makes the push/print loops exit so join() can return.
  _running = false;
  _async_push_dense_thread.join();
  _async_push_sparse_thread.join();
  // _print_thread.join();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
  // Stop(1000): give in-flight server-side RPCs up to 1000ms to finish.
  _server.Stop(1000);
  _server.Join();
  _server_started = false;
  VLOG(0) << "BrpcPsClient::FinalizeWorker done";
}
// Broadcast PS_STOP_SERVER to all pservers (no parameters).
std::future<int32_t> BrpcPsClient::StopServer() {
  return SendCmd(-1, PS_STOP_SERVER, std::vector<std::string>());
}
// Broadcast PS_START_PROFILER to all pservers (no parameters).
std::future<int32_t> BrpcPsClient::StartProfiler() {
  return SendCmd(-1, PS_START_PROFILER, std::vector<std::string>());
}
// Broadcast PS_STOP_PROFILER to all pservers (no parameters).
std::future<int32_t> BrpcPsClient::StopProfiler() {
  return SendCmd(-1, PS_STOP_PROFILER, std::vector<std::string>());
}
// Enter a barrier of the given type on one table; the barrier type travels
// as a stringified command parameter.
std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
                                           uint32_t barrier_type) {
  std::vector<std::string> args;
  args.push_back(std::to_string(barrier_type));
  return SendCmd(table_id, PS_BARRIER, args);
}
// GEO mode: pull accumulated sparse deltas from ONE pserver (`pserver_idx`).
// The response attachment is a packed buffer:
//   [uint32 shard_nums][uint64 keys * shard_nums][float values * shard_nums * update_size]
// `keys` and `values` are resized and filled in the RPC callback, so they
// must outlive the returned future.
std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
                                                std::vector<float> *values,
                                                std::vector<uint64_t> *keys,
                                                int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        uint32_t shard_nums;
        if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
          ret = -1;
        }
        // NOTE(review): the buffer is parsed even when check_response failed;
        // presumably the attachment is empty/harmless then — TODO confirm.
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        // Leading uint32: number of key/value pairs that follow.
        io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
                                       sizeof(uint32_t));
        keys->resize(shard_nums);
        values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
        io_buffer_itr.copy_and_forward((void *)(keys->data()),  // NOLINT
                                       sizeof(uint64_t) * shard_nums);
        io_buffer_itr.copy_and_forward(
            (void *)(values->data()),  // NOLINT
            shard_nums * accessor->GetAccessorInfo().update_size);
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// for GEO
// GEO mode: push full sparse parameter values (not gradients) to the
// pservers. Keys are routed to shards by `key % server_count`; each shard's
// request payload is [keys...][values...] with `kv_size` sent as a side
// parameter. `done` must be a DownpourBrpcClosure sized for all shards.
std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
                                                   const uint64_t *keys,
                                                   const float **update_values,
                                                   size_t num,
                                                   void *done) {
  auto *accessor = GetTableAccessor(table_id);
  // Send the RPC requests.
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  // Bucket every key (and its value pointer) into its destination shard.
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = keys[i] % request_call_num;
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    // References, not copies: the original copied both per-shard vectors on
    // every iteration.
    const auto &kvs = ids[shard_idx];
    const auto &value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Build and send the per-shard RPC request.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    auto *push_data = push_request->mutable_data();
    // Payload layout: kv_size keys followed by kv_size fixed-width values.
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Pull the dense table shards from all pservers and scatter the concatenated
// result into the caller-supplied `regions` (filled in shard order). Each
// shard must return exactly num_per_shard * select_size bytes; a size
// mismatch fails the whole pull.
std::future<int32_t> BrpcPsClient::PullDense(Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
  auto *accessor = GetTableAccessor(table_id);
  auto fea_dim = accessor->GetAccessorInfo().fea_dim;
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Callback: copy each shard's result into the regions, in order.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, num_per_shard, regions, region_num, accessor](
          void *done) {
        int ret = 0;
        size_t region_idx = 0;       // index of the region being filled
        size_t region_data_idx = 0;  // write offset inside that region
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        size_t shard_data_size =
            num_per_shard * accessor->GetAccessorInfo().select_size;
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          size_t shard_buffer_remain = res_io_buffer.size();
          // Every shard must return exactly its slice of the table.
          if (shard_buffer_remain != shard_data_size) {
            LOG(ERROR) << "expect res_size:" << shard_data_size
                       << ", but size:" << shard_buffer_remain
                       << ", ignore this response";
            ret = -1;
            break;
          }
          // Drain this shard's buffer into one or more regions.
          while (shard_buffer_remain > 0 && region_idx < region_num) {
            auto &region = regions[region_idx];
            if (region.size - region_data_idx >= shard_buffer_remain) {
              // Region has room for the whole remaining shard buffer: copy it
              // in directly.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  shard_buffer_remain);
              region_data_idx += shard_buffer_remain;
              shard_buffer_remain = 0;
            } else if (region.size - region_data_idx == 0) {
              // Region is full: advance to the next region.
              ++region_idx;
              region_data_idx = 0;
            } else {
              // Region cannot hold everything: copy as much as fits, then
              // move on.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  region.size - region_data_idx);
              shard_buffer_remain -= (region.size - region_data_idx);
              ++region_idx;
              region_data_idx = 0;
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    // Each server is told how many dense params its shard holds.
    closure->request(i)->add_params((char *)&num_per_shard,  // NOLINT
                                    sizeof(num_per_shard));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Push full dense parameter values to the pservers. The caller's regions are
// first partitioned into per-shard region lists (so the copies can proceed
// per shard), then each shard's request attachment is built as
// [uint32 num_per_shard][region bytes...], zero-padded up to the shard's
// fixed size so all shards stay aligned.
std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
                                                  size_t region_num,
                                                  size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  auto accessor_info = accessor->GetAccessorInfo();
  size_t request_call_num = _server_channels.size();
  // 1. Split the region data across shards; the per-shard copies below can
  // then run independently.
  std::vector<std::vector<Region>> regions_partition(request_call_num);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor_info.fea_dim, request_call_num);
  size_t shard_data_size = num_per_shard * accessor_info.update_size;
  size_t current_region_idx = 0;
  size_t current_region_data_idx = 0;
  for (size_t i = 0; i < request_call_num; ++i) {
    size_t shard_data_remain_size = shard_data_size;
    // Walk the regions, carving sub-regions until this shard's quota is met
    // or the input is exhausted.
    while (shard_data_remain_size > 0 && current_region_idx < region_num) {
      const auto &region = regions[current_region_idx];
      size_t region_remain_size = region.size - current_region_data_idx;
      if (shard_data_remain_size >= region_remain_size) {
        // Shard can absorb the rest of this region.
        regions_partition[i].push_back(Region(
            region.data + current_region_data_idx, region_remain_size));
        ++current_region_idx;
        current_region_data_idx = 0;
        shard_data_remain_size -= region_remain_size;
      } else {
        // Region is larger than the shard's remaining quota: split it.
        regions_partition[i].push_back(Region(
            region.data + current_region_data_idx, shard_data_remain_size));
        current_region_data_idx += shard_data_remain_size;
        shard_data_remain_size = 0;
      }
    }
  }
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
  // Zero-filled scratch buffer used to pad each shard up to its full size.
  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];
  // Copy per shard and fire the requests.
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&num_per_shard),
                          sizeof(uint32_t));
    auto &region_list = regions_partition[i];
    size_t fill_remain_size = shard_data_size;
    for (auto &region : region_list) {
      fill_remain_size -= region.size;
      request_buffer.append(reinterpret_cast<void *>(region.data),
                            region.size);
    }
    // Keep every shard's payload the same size: pad the remainder.
    while (fill_remain_size > 0) {
      size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
                            ? REGION_ASSIGN_BUFFER_SIZE
                            : fill_remain_size;
      request_buffer.append(reinterpret_cast<void *>(region_assign_buffer),
                            fill_num);
      fill_remain_size -= fill_num;
    }
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Push raw sparse gradients to the pservers. Keys are routed with
// get_sparse_shard using the table's configured shard_num (falling back to
// FLAGS_pserver_sparse_table_shard_num if the table is not found in the
// server config). Each shard's payload is [keys...][values...]; `done` must
// be a DownpourBrpcClosure sized for all shards.
std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    size_t num,
    void *done) {
  auto *accessor = GetTableAccessor(table_id);
  // Send the RPC requests.
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  // Look up this table's shard count; default to the global flag.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket keys (and their gradient pointers) by destination server.
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    // References, not copies: the original copied both per-shard vectors on
    // every iteration.
    const auto &kvs = ids[shard_idx];
    const auto &value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Build and send the per-shard RPC request.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    auto *push_data = push_request->mutable_data();
    // Payload layout: kv_size keys followed by kv_size fixed-width values.
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Push a contiguous dense gradient buffer to the pservers. `total_send_data`
// is sliced evenly: shard i receives elements
// [i * num_per_shard, (i + 1) * num_per_shard). Each request's payload is
// [uint32 num_per_shard][float gradients...].
// NOTE(review): total_send_data_size is not validated against
// request_call_num * num_per_shard — presumably the caller guarantees it.
std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
    int table_id, float *total_send_data, size_t total_send_data_size,
    void *done) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  auto *accessor = GetTableAccessor(table_id);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Payload: shard element count, then that shard's slice of the gradient.
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    // closure->cntl(i)->set_request_compress_type(
    //     (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcast the global training step to every pserver. Reuses the dense-push
// payload format with a single int64 element per shard:
// [uint32 num_per_shard = 1][int64 step].
std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
                                                  int64_t *total_send_data,
                                                  void *done) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_GLOBAL_STEP);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Every shard receives the same single-element payload.
    int32_t num_per_shard = 1;
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data,
           num_per_shard * sizeof(int64_t));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pull sparse values for `num` keys into the caller-provided per-key output
// pointers. Keys are bucketed per server (get_sparse_shard), sorted, and
// deduplicated on the wire: each distinct key is sent once together with its
// repeat count, and the response value is fanned out to every duplicate's
// output pointer. Request attachment layout per server:
//   [bool is_training][uint64 distinct keys...][uint32 counts...]
std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
                                              size_t table_id,
                                              const uint64_t *keys,
                                              size_t num,
                                              bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
  auto local_timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
  size_t request_call_num = _server_channels.size();
  // shared_ptr so the per-shard (key, output-ptr) lists stay alive inside
  // the async callback.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  // Look up this table's shard count; default to the global flag.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // Entries are sorted by key, so duplicates are adjacent: the first
          // occurrence reads from the buffer, repeats copy from it.
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data),
                      value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    // Sort by key so duplicate keys can be collapsed below.
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&is_training),
                          sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse a run of identical keys into one wire entry + count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // Nothing to ask this server: run the closure slot locally so the
      // completion count still reaches request_call_num.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(closure->cntl(i),
                       closure->request(i),
                       closure->response(i),
                       closure);
    }
  }
  return fut;
}
// for GEO
// GEO-mode variant of PullSparse: identical sort/dedup wire protocol, but
// keys are routed with plain modulo (`key % server_count`) instead of
// get_sparse_shard, and a different cost-timer label is used.
std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
                                                   size_t table_id,
                                                   const uint64_t *keys,
                                                   size_t num,
                                                   bool is_training) {
  auto timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
  size_t request_call_num = _server_channels.size();
  // shared_ptr so the per-shard (key, output-ptr) lists stay alive inside
  // the async callback.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = keys[i] % request_call_num;
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // can remove sort&unique
          // Duplicates are adjacent after the sort: first occurrence reads
          // from the buffer, repeats copy from it.
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data),
                      value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    // Sort by key so duplicate keys can be collapsed below.
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&is_training),
                          sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse a run of identical keys into one wire entry + count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // Nothing to ask this server: complete the closure slot locally.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(closure->cntl(i),
                       closure->request(i),
                       closure->response(i),
                       closure);
    }
  }
  return fut;
}
// Send a peer-to-peer message to another worker client. The response is
// validated against cmd id (msg_type + 1000) per the client-to-client
// protocol. Resolves to 0 on success, -1 on an invalid target id or a bad
// response.
std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
    int msg_type, int to_client_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  // Reject any out-of-range target. The original guard was
  // `to_client_id >= 0 && (size_t)to_client_id >= size`, which let a
  // NEGATIVE id fall through to _client_channels[to_client_id] — an
  // out-of-bounds access.
  if (to_client_id < 0 ||
      static_cast<size_t>(to_client_id) >= _client_channels.size()) {
    VLOG(0) << "to_client_id is out of range clients, which size is "
            << _client_channels.size();
    promise->set_value(-1);
    return fut;
  }
  auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(msg_type);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_client_channels[to_client_id].get());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Push raw sparse gradients for a pre-bucketed key set to a SINGLE pserver
// (`pserver_idx`). Unlike PushSparseRawGradient, no sharding is done here:
// all `num` keys go into one request with payload [keys...][values...].
// `done` must be a DownpourBrpcClosure with a single request slot.
std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    uint32_t num,
    void *done,
    int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().update_size;
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Build and send the RPC request.
  auto *push_request = closure->request(0);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  push_request->add_params((char *)&num, sizeof(uint32_t));  // NOLINT
  auto *push_data = push_request->mutable_data();
  // Payload layout: num keys followed by num fixed-width gradient values.
  push_data->resize(num * (sizeof(uint64_t) + value_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  memcpy(push_data_ptr, keys, num * sizeof(uint64_t));
  push_data_ptr += num * sizeof(uint64_t);
  for (uint32_t i = 0; i < num; ++i) {
    memcpy(push_data_ptr, update_values[i], value_size);
    push_data_ptr += value_size;
  }
  PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
  closure->cntl(0)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pulls the full contents of sparse table `table_id` from the servers and
// serializes it as a LoDTensor stream to `<path>/<table_name>`.
//
// Looks up the table's name/row-count/dim in the worker config, pulls rows
// for keys 0..var_num-1 (assumes keys are dense consecutive ids — the save
// keys are simply 0..N-1), then writes the result via SerializeToStream.
//
// @return 0 on success; enforcement macros abort on unknown table id or
//         unopenable output file.
int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
                                       const std::string &path) {
  // get var information
  std::string var_name = "";
  int64_t var_num = 0;    // number of rows (table_num)
  int64_t var_shape = 0;  // columns per row (table_dim)
  std::string table_class;
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    if (worker_param.downpour_table_param(i).table_id() == table_id) {
      var_name = worker_param.downpour_table_param(i).common().table_name();
      var_num = worker_param.downpour_table_param(i).common().table_num();
      var_shape = worker_param.downpour_table_param(i).common().table_dim();
      table_class = worker_param.downpour_table_param(i).table_class();
      break;
    }
  }
  PADDLE_ENFORCE_NE(var_name,
                    "",
                    platform::errors::InvalidArgument(
                        "Cannot find table id %d to save variables.",
                        table_id));
  std::string var_store = string::Sprintf("%s", path);
  MkDirRecursively(var_store.c_str());
  // pull sparse from server
  std::vector<float> save_huge_vec(var_num * var_shape);
  std::vector<uint64_t> save_key(var_num);
  std::vector<float *> save_vec;
  for (size_t i = 0; i < save_key.size(); ++i) {
    save_key[i] = i;
    // each row's destination slice inside the flat buffer
    save_vec.push_back(save_huge_vec.data() + i * var_shape);
  }
  VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
  // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
  // RecvAndSaveTable
  if (table_class == "MemorySparseGeoTable") {
    auto status = PullSparseParam(reinterpret_cast<float **>(save_vec.data()),
                                  table_id,
                                  save_key.data(),
                                  save_key.size(),
                                  true);
    status.wait();
  } else {
    auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
                             table_id,
                             save_key.data(),
                             save_key.size(),
                             true);
    status.wait();
  }
  // create lod tensor
  std::shared_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  auto place = platform::CPUPlace();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);
  framework::Variable *var = scope->Var(var_name);
  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int64_t> vec_dim = {var_num, var_shape};
  var_tensor->Resize(phi::make_ddim(vec_dim));
  // copy and save
  float *tensor_data = var_tensor->mutable_data<float>(place);
  memcpy(tensor_data,
         save_huge_vec.data(),
         var_num * var_shape * sizeof(float));
  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
  std::ofstream fout(file_name, std::ios::binary);
  PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
                    true,
                    platform::errors::Unavailable(
                        "Cannot open %s to save variables.", file_name));
  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
  fout.close();
  return 0;
}
// Enqueues an asynchronous sparse gradient push: partitions the (key, value)
// pairs by destination server shard, stages them into a pooled
// SparseAsyncTask, and puts the task on the per-table queue consumed by
// PushSparseTaskConsume().
//
// Applies back-pressure: blocks (polling every 5ms) while the table's queue
// exceeds FLAGS_pserver_max_async_call_num.
//
// @param table_id      sparse table id.
// @param keys          `num` feature keys.
// @param update_values `num` pointers to gradient blocks of update_size bytes.
// @param num           number of pairs.
// @return future resolved when the task's eventual RPC completes.
std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
                                              const uint64_t *keys,
                                              const float **update_values,
                                              size_t num) {
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
  CostTimer parse_timer("pserver_client_push_sparse_parse");
  int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  // Throttle until the async queue drains below the configured limit.
  while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
    // task_num:"
    // << push_sparse_async_num
    // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  }
  auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
  // Per-thread scratch buckets, one per server shard, reused across calls.
  thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
      shard_sorted_kv_list;
  auto *accessor = GetTableAccessor(table_id);
  size_t request_call_num = _server_channels.size();
  shard_sorted_kv_list.resize(request_call_num);
  for (auto &x : shard_sorted_kv_list) {
    x.clear();
  }
  const auto &server_param = _config.server_param().downpour_server_param();
  // Default shard count, overridden by the table's own config when present.
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket each pair by its destination shard.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]});
  }
  auto sparse_task_data = _sparse_task_pool.get();
  sparse_task_data->shared_data.resize(request_call_num);
  auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer);
  // Copy each shard's pairs into the task's owned key/value buffers; the
  // value bytes are copied (assign) so the caller's pointers need not outlive
  // this call.
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kv_list = shard_sorted_kv_list[i];
    size_t sorted_kv_size = sorted_kv_list.size();
    auto &shard_kv_data = async_task->data()->shared_data[i];
    shard_kv_data.key_list.resize(sorted_kv_size);
    shard_kv_data.value_list.resize(sorted_kv_size);
    if (sorted_kv_size == 0) {
      shard_kv_data.kv_num = 0;
      continue;
    }
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
      shard_kv_data.value_list[kv_idx].assign(
          (const char *)sorted_kv_list[kv_idx].second, value_size);
    }
    shard_kv_data.kv_num = sorted_kv_size;
  }
  std::future<int> fut = async_task->get_future();
  _push_sparse_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background consumer loop for asynchronous sparse pushes. Runs until
// _running goes false. For each table's queue it drains pending tasks,
// merges them per server shard on a thread pool, and either sends the merged
// result (once the accumulated merge count reaches
// FLAGS_pserver_push_sparse_merge_limit, or during a flush) or re-queues the
// merged task for further accumulation.
void BrpcPsClient::PushSparseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
  std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
  size_t request_call_num = _server_channels.size();
  ::ThreadPool async_push_sparse_shard_threads(
      FLAGS_pserver_sparse_merge_thread);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    // Process the pending push tasks of every sparse table.
    for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
      auto table_id = push_sparse_task_itr.first;
      auto *accessor = GetTableAccessor(table_id);
      auto &task_queue = push_sparse_task_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      // With merging enabled, wait for at least two tasks unless flushing.
      if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) {
        continue;
      }
      ++_async_call_num;
      int merge_count = 0;
      // Recycle the data buffers of the previous round's tasks back into the
      // pool before reusing task_list.
      for (size_t i = 0; i < task_list.size(); ++i) {
        if (task_list[i]->data()) {
          _sparse_task_pool.push(task_list[i]->data());
        }
      }
      auto sparse_task_data = _sparse_task_pool.get();
      task_list.clear();
      int cur_meger_size = task_queue->Size();
      // task_list[0] is an empty SparseAsyncTask; the per-shard async merge
      // results are written into it.
      sparse_task_data->shared_data.resize(request_call_num);
      auto push_timer =
          std::make_shared<CostTimer>("pserver_client_push_sparse");
      auto async_task =
          new SparseAsyncTask(sparse_task_data, table_id, push_timer);
      task_list.reserve(cur_meger_size + 1);
      task_list.push_back(
          std::move(std::shared_ptr<SparseAsyncTask>(async_task)));
      // Drain up to cur_meger_size queued tasks into task_list[1..].
      while (!task_queue->Empty() && merge_count < cur_meger_size) {
        ++merge_count;
        SparseAsyncTask *task;
        task_queue->Get(task);
        task_list.push_back(std::shared_ptr<SparseAsyncTask>(task));
      }
      _push_sparse_merge_count_map[table_id] += merge_count;
      // Send once the merged count reaches merge_size (or when flushing).
      std::vector<int> request_kv_num(request_call_num, 0);
      if (_push_sparse_merge_count_map[table_id] >= merge_size ||
          _flushing == true) {
        // RPC completion callback: any failed shard response fails the whole
        // batch; decrements the outstanding async counter.
        DownpourBrpcClosure *closure = new DownpourBrpcClosure(
            request_call_num, [this, request_call_num](void *done) {
              int ret = 0;
              auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
              for (size_t i = 0; i < request_call_num; ++i) {
                if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
                  ret = -1;
                  break;
                }
              }
              closure->set_promise_value(ret);
              --_async_call_num;
            });
        // Chain every merged task's timer/promise onto the closure so all of
        // them resolve when the batched RPC completes.
        for_each(task_list.begin() + 1,
                 task_list.end(),
                 [&request_kv_num, request_call_num, closure](
                     std::shared_ptr<SparseAsyncTask> &task) {
                   closure->add_timer(task->timer());
                   closure->add_promise(task->promise());
                 });
        CostTimer merge_timer("pserver_client_push_sparse_merge");
        auto rpc_timer =
            std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
        closure->add_timer(rpc_timer);
        // Merge + send each shard in parallel, then wait for all of them.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardPush,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        closure,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
        _push_sparse_merge_count_map[table_id] = 0;
      } else {
        // Below the threshold: only do the multi-way merge, no RPC yet.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardMerge,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        // Re-queue a copy of the merged task_list[0] so a later round can
        // keep accumulating into it.
        auto async_task = new SparseAsyncTask(*(task_list[0].get()));
        task_queue->Put(std::move(async_task));
        --_async_call_num;
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
      }
    }
    // Sleep out the remainder of the configured polling interval.
    auto wait_ms = FLAGS_pserver_async_push_sparse_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Merges one sparse update row (`another_data`) into `merge_data` in place
// using the accessor's column-wise Merge() API.
//
// @param accessor     table accessor providing update_dim and Merge().
// @param merge_data   destination row of `update_dim` floats, modified.
// @param another_data source row of `update_dim` floats.
void sparse_local_merge(ValueAccessor *accessor,
                        float *merge_data,
                        const float *another_data) {
  size_t col_num = accessor->GetAccessorInfo().update_dim;
  // Merge() takes per-column pointer arrays. The original built these with C
  // variable-length arrays (`float *shell[col_num]`), which are not standard
  // C++ (a GCC extension, unbounded stack use); use vectors instead.
  std::vector<float *> merge_data_shell(col_num);
  std::vector<const float *> another_data_shell(col_num);
  for (size_t i = 0; i < col_num; ++i) {
    merge_data_shell[i] = merge_data + i;
    another_data_shell[i] = another_data + i;
  }
  accessor->Merge(merge_data_shell.data(), another_data_shell.data(), 1);
}
// Merges shard `shard_idx` of every task in task_list[1..] into
// task_list[0]'s buffers: collects all (key, value) pairs, sorts by key, and
// locally merges duplicate keys via sparse_local_merge().
//
// @param task_list      task_list[0] receives the merged output; the rest are
//                       inputs.
// @param request_kv_num unused here (kept for the thread-pool bind signature).
// @param table_id       table id (unused in the merge itself).
// @param shard_idx      which server shard's data to merge.
// @param accessor       provides update_size and the Merge() operation.
// @return always 0.
int BrpcPsClient::PushSparseAsyncShardMerge(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    ValueAccessor *accessor) {
  size_t merged_kv_count = 0;
  uint32_t value_size = accessor->GetAccessorInfo().update_size;
  // Per-thread scratch list reused across invocations.
  thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
  sorted_kv_list.clear();
  // Gather this shard's pairs from every input task; skip values that are
  // too short to be a full update block.
  for (size_t i = 1; i < task_list.size(); ++i) {
    size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
    auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
    auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;
    for (size_t j = 0; j < kv_num; ++j) {
      if (value_list[j].size() < value_size) {
        LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
                     << "is invalid.";
        continue;
      }
      char *task_data_ptr = const_cast<char *>(value_list[j].data());
      sorted_kv_list.push_back(
          {key_list[j], reinterpret_cast<float *>(task_data_ptr)});
    }
  }
  // Sort by key so duplicates become adjacent for dedup/merge.
  std::sort(sorted_kv_list.begin(),
            sorted_kv_list.end(),
            [](const std::pair<uint64_t, const float *> &k1,
               const std::pair<uint64_t, const float *> &k2) {
              return k1.first < k2.first;
            });
  auto &async_task = task_list[0];
  size_t sorted_kv_size = sorted_kv_list.size();
  auto &shard_kv_data = async_task->data()->shared_data[shard_idx];
  shard_kv_data.key_list.resize(sorted_kv_size);
  shard_kv_data.value_list.resize(sorted_kv_size);
  // Fast paths: nothing, or a single pair (no merging needed).
  if (sorted_kv_size == 0) {
    shard_kv_data.kv_num = 0;
    return 0;
  } else if (sorted_kv_size == 1) {
    shard_kv_data.kv_num = 1;
    shard_kv_data.key_list[0] = sorted_kv_list[0].first;
    shard_kv_data.value_list[0].assign(
        (const char *)(sorted_kv_list[0].second), value_size);
    return 0;
  }
  // Dedup + local merge. `merger_buffer` holds the accumulated value for a
  // run of equal keys; last_merge_data is non-NULL only while such a run is
  // being accumulated.
  uint64_t last_key = sorted_kv_list[0].first;
  const float *last_value_data = sorted_kv_list[0].second;
  float *last_merge_data = NULL;
  std::shared_ptr<char> merger_buffer(new char[value_size],
                                      array_deleter<char>());
  for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) {
    // Fold every pair sharing last_key into the merge buffer.
    while (kv_idx < sorted_kv_size &&
           last_key == sorted_kv_list[kv_idx].first) {
      if (last_merge_data == NULL) {
        last_merge_data = reinterpret_cast<float *>(merger_buffer.get());
        memcpy(last_merge_data, last_value_data, value_size);
      }
      sparse_local_merge(accessor, last_merge_data,
                         sorted_kv_list[kv_idx].second);
      ++kv_idx;
    }
    // Emit the finished run: either the merged buffer or the single
    // unmerged value that preceded kv_idx.
    if (last_merge_data != NULL) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_merge_data, value_size);
      last_merge_data = NULL;
    } else {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)sorted_kv_list[kv_idx - 1].second, value_size);
    }
    shard_kv_data.key_list[merged_kv_count++] = last_key;
    if (kv_idx < sorted_kv_size) {
      last_key = sorted_kv_list[kv_idx].first;
      last_value_data = sorted_kv_list[kv_idx].second;
    }
    // NOTE(review): the final element is emitted here only when the loop
    // lands exactly on sorted_kv_size - 1; verify a trailing duplicate run
    // that advances kv_idx past that index is still emitted as intended.
    if (kv_idx == sorted_kv_size - 1) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_value_data, value_size);
      shard_kv_data.key_list[merged_kv_count++] = last_key;
    }
  }
  shard_kv_data.kv_num = merged_kv_count;
  return 0;
}
// Merges shard `shard_idx` across task_list (via PushSparseAsyncShardMerge)
// and then sends the merged key/value batch to that shard's server over the
// sparse channel, using sub-request `shard_idx` of `closure`.
//
// Wire layout mirrors PushSparseRawGradientPartial: the kv count travels in
// params as a uint32, the data field holds keys then fixed-size values.
//
// @return always 0; RPC completion is reported through `closure`.
int BrpcPsClient::PushSparseAsyncShardPush(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    DownpourBrpcClosure *closure,
    ValueAccessor *accessor) {
  PushSparseAsyncShardMerge(
      task_list, request_kv_num, table_id, shard_idx, accessor);
  size_t merged_kv_count =
      task_list[0]->data()->shared_data[shard_idx].kv_num;
  auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list;
  auto &merged_value_list =
      task_list[0]->data()->shared_data[shard_idx].value_list;
  // Build and send the RPC request.
  auto *push_request = closure->request(shard_idx);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  // BUGFIX: the original passed `reinterpret_cast<char *>(&merged_kv_count)`
  // (a size_t) while copying only sizeof(uint32_t) bytes — that silently
  // depends on little-endian layout and truncation-by-prefix. Narrow to a
  // real uint32_t first so the serialized count is well-defined.
  uint32_t kv_count = static_cast<uint32_t>(merged_kv_count);
  push_request->add_params(reinterpret_cast<char *>(&kv_count),
                           sizeof(uint32_t));
  auto *push_data = push_request->mutable_data();
  int update_size = accessor->GetAccessorInfo().update_size;
  push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  memcpy(push_data_ptr,
         merged_key_list.data(),
         merged_kv_count * sizeof(uint64_t));
  push_data_ptr += merged_kv_count * sizeof(uint64_t);
  for (size_t i = 0; i < merged_kv_count; ++i) {
    const char *task_data_ptr = merged_value_list[i].data();
    memcpy(push_data_ptr,
           (float *)(task_data_ptr),  // NOLINT
           update_size);
    push_data_ptr += update_size;
  }
  PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
  closure->cntl(shard_idx)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(closure->cntl(shard_idx),
                   closure->request(shard_idx),
                   closure->response(shard_idx),
                   closure);
  _push_sparse_merge_count_map[table_id] = 0;
  return 0;
}
// Enqueues an asynchronous dense gradient push: flattens the caller's regions
// into one contiguous buffer owned by a DenseAsyncTask and puts it on the
// per-table queue consumed by PushDenseTaskConsume().
//
// Applies back-pressure: blocks (polling every 5ms) while the table's queue
// exceeds FLAGS_pserver_max_async_call_num.
//
// @param regions    `region_num` byte regions whose payload is float data.
// @param region_num number of regions.
// @param table_id   dense table id.
// @return future resolved when the task's eventual RPC completes.
std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  int fea_dim = accessor->GetAccessorInfo().fea_dim;
  int update_dim = accessor->GetAccessorInfo().update_dim;
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
  auto parse_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_parse");
  int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  // Throttle until the async queue drains below the configured limit.
  while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushDense Waiting for async_call_num comsume,
    // task_num:"
    // << push_dense_async_num
    // << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  }
  auto push_dense_timer = std::make_shared<CostTimer>("push_dense_put");
  // auto dense_data = _dense_matrix_obj_pool.get();
  auto dense_data = std::make_shared<std::vector<float>>();
  auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Copy the region data into the task's flat (per-shard laid out) buffer.
  async_task->data()->resize(num_per_shard * request_call_num * update_dim);
  float *data = async_task->data()->data();
  size_t data_size = async_task->data()->size();
  uint32_t pos = 0;
  for (size_t i = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    CHECK(pos + data_num <= data_size)
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << data_num << "] size[" << data_size << "]";
    const float *region_data = (const float *)(regions[i].data);
    memcpy(data + pos, region_data, regions[i].size);
    pos += data_num;
  }
  std::future<int> fut = async_task->get_future();
  _push_dense_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
void
BrpcPsClient
::
PushDenseTaskConsume
()
{
uint64_t
merge_size
=
FLAGS_pserver_push_dense_merge_limit
;
static
bool
scale_gradient
=
FLAGS_pserver_scale_gradient_by_merge
;
::
ThreadPool
async_merge_dense_threads
(
10
);
while
(
_running
)
{
auto
async_start_time_ms
=
butil
::
gettimeofday_ms
();
for
(
auto
&
task_queue_itr
:
_push_dense_task_queue_map
)
{
auto
&
task_queue
=
task_queue_itr
.
second
;
auto
queue_size
=
task_queue
->
Size
();
if
(
queue_size
==
0
)
{
continue
;
}
if
(
queue_size
<=
merge_size
&&
_flushing
==
false
)
{
continue
;
}
++
_async_call_num
;
DenseAsyncTask
*
task
;
task_queue
->
Get
(
task
);
auto
*
accessor
=
GetTableAccessor
(
task
->
table_id
());
// 设置请求回调
size_t
request_call_num
=
_server_channels
.
size
();
DownpourBrpcClosure
*
closure
=
new
DownpourBrpcClosure
(
request_call_num
,
[
this
,
request_call_num
](
void
*
done
)
{
int
ret
=
0
;
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
);
for
(
size_t
i
=
0
;
i
<
request_call_num
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
}
closure
->
set_promise_value
(
ret
);
--
_async_call_num
;
});
auto
&
total_send_data_vec
=
*
(
task
->
data
());
float
*
total_send_data
=
reinterpret_cast
<
float
*>
(
total_send_data_vec
.
data
());
size_t
total_send_data_size
=
total_send_data_vec
.
size
();
{
CostTimer
merge_timer
(
"pserver_client_push_dense_merge"
);
uint32_t
merge_count
=
0
;
std
::
vector
<
std
::
future
<
int
>>
merge_status
(
merge_size
);
while
(
!
task_queue
->
Empty
()
&&
merge_count
<
merge_size
)
{
auto
*
async_task
=
new
DenseAsyncTask
();
task_queue
->
Get
(
async_task
);
closure
->
add_timer
(
async_task
->
timer
());
closure
->
add_promise
(
async_task
->
promise
());
merge_status
[
merge_count
]
=
async_merge_dense_threads
.
enqueue
([
closure
,
accessor
,
&
total_send_data
,
total_send_data_size
,
async_task
]()
->
int
{
auto
&
tmp_task_vec
=
*
(
async_task
->
data
());
const
float
*
merge_data
=
tmp_task_vec
.
data
();
accessor
->
Merge
(
&
total_send_data
,
&
merge_data
,
total_send_data_size
);
#pragma optimize("", off)
delete
async_task
;
#pragma optimize("", on)
return
0
;
});
++
merge_count
;
}
for
(
size_t
i
=
0
;
i
<
merge_count
;
++
i
)
{
merge_status
[
i
].
wait
();
}
VLOG
(
3
)
<<
"BrpcPsClient::PushDenseTaskConsume before merge "
"total_send_data[0]"
<<
total_send_data
[
0
]
<<
" total_send_data[-2]"
<<
total_send_data
[
total_send_data_size
-
2
]
<<
total_send_data
[
0
]
<<
" total_send_data[-1]"
<<
total_send_data
[
total_send_data_size
-
1
];
if
(
scale_gradient
&&
merge_count
>
1
)
{
Eigen
::
Map
<
Eigen
::
MatrixXf
>
mat
(
total_send_data
,
1
,
total_send_data_size
);
mat
*=
(
1.0
/
(
merge_count
+
1
));
}
VLOG
(
3
)
<<
"BrpcPsClient::PushDenseTaskConsume after merge "
"total_send_data[0]"
<<
total_send_data
[
0
]
<<
" total_send_data[-2]"
<<
total_send_data
[
total_send_data_size
-
2
]
<<
" total_send_data[-1]"
<<
total_send_data
[
total_send_data_size
-
1
]
<<
" merge_count "
<<
merge_count
;
}
std
::
shared_ptr
<
DenseAsyncTask
>
task_ptr
(
task
);
PushDenseRawGradient
(
task_ptr
,
total_send_data
,
total_send_data_size
,
closure
);
}
auto
wait_ms
=
FLAGS_pserver_async_push_dense_interval_ms
-
(
butil
::
gettimeofday_ms
()
-
async_start_time_ms
);
if
(
wait_ms
>
0
)
{
usleep
(
wait_ms
*
1000
);
}
}
}
// Sends a merged dense gradient buffer to every server shard. Each shard's
// request carries `num_per_shard` as a uint32 header followed by its slice of
// `total_send_data`; completion is reported through `closure`.
//
// @param task                 the accumulator task (provides the table id).
// @param total_send_data      flat gradient buffer laid out shard-by-shard.
// @param total_send_data_size total number of floats in the buffer (unused
//                             here beyond documentation of the layout).
// @param closure              closure with request_call_num sub-requests.
void BrpcPsClient::PushDenseRawGradient(
    std::shared_ptr<DenseAsyncTask> &task,
    float *total_send_data,
    size_t total_send_data_size,
    DownpourBrpcClosure *closure) {
  auto *accessor = GetTableAccessor(task->table_id());
  size_t request_call_num = _server_channels.size();
  // Copy each shard's slice into its request buffer.
  auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
  closure->add_timer(timer);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  auto send_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_send");
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(task->table_id());
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Payload: [uint32 num_per_shard | num_per_shard floats].
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    closure->cntl(i)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_client.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace
brpc
{
class
Channel
;
class
Controller
;
}
// namespace brpc
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
distributed
{
struct
Region
;
// brpc service hosted BY the client so that peers (other clients / the
// coordinator) can call into it: generic client2client messages via
// service(), federated-learning strategy delivery via FLService().
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}
  // Binds this service to its owning client and records our rank.
  // Always returns 0.
  virtual int32_t Configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
  // Generic client2client request entry point (defined elsewhere).
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done);
  // Receives an FL strategy string from the coordinator; stores it in
  // _fl_strategy and flips _is_fl_strategy_ready so PullFlStrategy() can
  // pick it up. Aborts (CHECK) if the request targets a different client id.
  virtual void FLService(::google::protobuf::RpcController *controller,
                         const CoordinatorReqMessage *request,
                         CoordinatorResMessage *response,
                         ::google::protobuf::Closure *done) {
    brpc::ClosureGuard done_guard(done);
    size_t client_id = request->client_id();
    CHECK(_client->_client_id == client_id)
        << "request client id not matched self";
    _fl_strategy = request->str_params();
    // NOTE(review): _is_fl_strategy_ready is a plain bool written here and
    // presumably polled from another thread — confirm synchronization at the
    // reader.
    _is_fl_strategy_ready = true;
    response->set_err_code(0);
    response->set_err_msg("");
    VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!";
    return;
  }

 public:
  std::string _fl_strategy;             // last strategy received via FLService
  bool _is_fl_strategy_ready = false;   // set when _fl_strategy is valid

 protected:
  size_t _rank;      // this client's rank id
  PSClient *_client; // owning client (not owned here)
};
// Closure for a fan-out of `num` federated-learning coordinator RPCs. It owns
// one controller/request/response triple per sub-request; when the last
// outstanding sub-request finishes, Run() fires the callback exactly once and
// then destroys the closure.
class FlClientBrpcClosure : public PSClientClosure {
 public:
  FlClientBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (auto &cntl : _cntls) {
      cntl.reset(new brpc::Controller());
    }
  }
  virtual ~FlClientBrpcClosure() {}
  // Called once per completed sub-request; the final completion triggers the
  // user callback and self-destruction.
  void Run() override {
    const bool last_one = (_waiting_num.fetch_sub(1) == 1);
    if (last_one) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for sub-request i's message/controller slots.
  CoordinatorReqMessage *request(size_t i) { return _requests.data() + i; }
  CoordinatorResMessage *response(size_t i) { return _responses.data() + i; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<CoordinatorReqMessage> _requests;
  std::vector<CoordinatorResMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Closure for a fan-out of `num` parameter-server RPCs. One
// controller/request/response triple is owned per sub-request; the last
// sub-request to finish runs the callback exactly once and deletes the
// closure.
class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _requests.resize(num);
    _responses.resize(num);
    _cntls.resize(num);
    for (size_t idx = 0; idx < num; ++idx) {
      _cntls[idx] = std::shared_ptr<brpc::Controller>(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
  // Invoked by brpc on each sub-request completion; self-destructs after the
  // final one fires the user callback.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for sub-request i's message/controller slots.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Per-shard staging buffer for one sparse push: parallel arrays of feature
// keys and their serialized value bytes, plus the count of valid entries.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  // Number of valid entries in key_list/value_list. BUGFIX: brace-initialized
  // to 0 — the original left it uninitialized, so reading it on a
  // default-constructed instance before it was filled yielded an
  // indeterminate count.
  size_t kv_num = 0;
  std::vector<uint64_t> key_list;       // feature keys
  std::vector<std::string> value_list;  // raw serialized value bytes per key
};
// One asynchronous sparse-push payload: the key/value data partitioned by
// destination shard (sparse data is sharded by key hash).
struct SparsePushTaskData {
  std::vector<SharedSparsePushData> shared_data;  // one entry per shard
};
// push sparse 对象池
struct
SparseTaskPool
{
std
::
shared_ptr
<
SparsePushTaskData
>
get
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex
);
if
(
_pool
.
empty
())
{
return
std
::
make_shared
<
SparsePushTaskData
>
();
}
else
{
auto
ret
=
_pool
.
back
();
_pool
.
pop_back
();
return
ret
;
}
}
void
push
(
std
::
shared_ptr
<
SparsePushTaskData
>
data
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex
);
_pool
.
push_back
(
std
::
move
(
data
));
}
std
::
vector
<
std
::
shared_ptr
<
SparsePushTaskData
>>
_pool
;
std
::
mutex
_mutex
;
};
template
<
class
T
>
struct
array_deleter
{
void
operator
()(
T
*&
x
)
const
{
delete
[]
x
;
}
// NOLINT
};
class
BrpcPsClient
:
public
PSClient
{
public:
BrpcPsClient
()
{}
virtual
~
BrpcPsClient
()
{
if
(
_running
)
{
Flush
();
_running
=
false
;
}
if
(
_async_push_dense_thread
.
joinable
())
{
_async_push_dense_thread
.
join
();
}
if
(
_async_push_sparse_thread
.
joinable
())
{
_async_push_sparse_thread
.
join
();
}
if
(
_server_started
)
{
_server
.
Stop
(
1000
);
_server
.
Join
();
_server_started
=
false
;
}
}
virtual
int32_t
CreateClient2ClientConnection
(
int
pserver_timeout_ms
,
int
pserver_connect_timeout_ms
,
int
max_retry
);
std
::
future
<
int32_t
>
Shrink
(
uint32_t
table_id
,
const
std
::
string
threshold
)
override
;
std
::
future
<
int32_t
>
Load
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Load
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Save
(
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Save
(
uint32_t
table_id
,
const
std
::
string
&
epoch
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
Clear
()
override
;
std
::
future
<
int32_t
>
Clear
(
uint32_t
table_id
)
override
;
std
::
future
<
int32_t
>
Revert
()
override
;
std
::
future
<
int32_t
>
CheckSavePrePatchDone
()
override
;
std
::
future
<
int32_t
>
StopServer
()
override
;
std
::
future
<
int32_t
>
StartProfiler
()
override
;
std
::
future
<
int32_t
>
StopProfiler
()
override
;
void
FinalizeWorker
()
override
;
virtual
std
::
future
<
int32_t
>
PullDense
(
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
virtual
std
::
future
<
int32_t
>
PushDenseParam
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
virtual
std
::
future
<
int32_t
>
PushDense
(
const
Region
*
regions
,
size_t
region_num
,
size_t
table_id
);
void
PushDenseTaskConsume
();
virtual
std
::
future
<
int32_t
>
PullSparse
(
float
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
virtual
std
::
future
<
int32_t
>
PullSparseParam
(
float
**
select_values
,
size_t
table_id
,
const
uint64_t
*
keys
,
size_t
num
,
bool
is_training
);
virtual
std
::
future
<
int32_t
>
PrintTableStat
(
uint32_t
table_id
);
virtual
std
::
future
<
int32_t
>
Barrier
(
size_t
table_id
,
uint32_t
barrier_type
);
virtual
std
::
future
<
int32_t
>
PullGeoParam
(
size_t
table_id
,
std
::
vector
<
float
>
*
values
,
std
::
vector
<
uint64_t
>
*
keys
,
int
pserver_idx
);
virtual
std
::
future
<
int32_t
>
PushGlobalStep
(
int
table_id
,
int64_t
*
total_send_data
,
void
*
done
);
virtual
std
::
future
<
int32_t
>
Flush
();
std
::
future
<
int32_t
>
SendClient2ClientMsg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
override
;
// for local save sparse
virtual
int32_t
RecvAndSaveTable
(
const
uint64_t
table_id
,
const
std
::
string
&
path
);
std
::
future
<
int32_t
>
CacheShuffle
(
uint32_t
table_id
,
const
std
::
string
&
path
,
const
std
::
string
&
mode
,
const
std
::
string
&
cache_threshold
)
override
;
std
::
future
<
int32_t
>
CacheShuffleMultiTable
(
std
::
vector
<
int
>
tables
,
const
std
::
string
&
path
,
const
std
::
string
&
mode
,
const
std
::
string
&
cache_threshold
);
std
::
future
<
int32_t
>
SaveCache
(
uint32_t
table_id
,
const
std
::
string
&
path
,
const
std
::
string
&
mode
)
override
;
std
::
future
<
int32_t
>
GetCacheThreshold
(
uint32_t
table_id
,
double
&
cache_threshold
)
override
;
void
PrintQueueSize
();
void
PrintQueueSizeThread
();
protected:
virtual
size_t
GetServerNums
()
{
return
_server_channels
.
size
();
}
inline
brpc
::
Channel
*
GetSparseChannel
(
size_t
server_id
)
{
return
_server_channels
[
server_id
][
0
].
get
();
}
inline
brpc
::
Channel
*
GetDenseChannel
(
size_t
server_id
)
{
return
_server_channels
[
server_id
][
1
].
get
();
}
inline
brpc
::
Channel
*
GetCmdChannel
(
size_t
server_id
)
{
return
_server_channels
[
server_id
][
2
].
get
();
}
int32_t
Initialize
()
override
;
// for fl
public:
virtual
int32_t
InitializeFlWorker
(
const
std
::
string
&
self_endpoint
);
int32_t
StartFlClientService
(
const
std
::
string
&
self_endpoint
);
virtual
void
PushFLClientInfoSync
(
const
std
::
string
&
fl_client_info
);
std
::
string
PullFlStrategy
();
// for fl
private:
inline
uint32_t
DenseDimPerShard
(
uint32_t
dense_dim_total
,
uint32_t
shard_num
)
{
return
dense_dim_total
/
shard_num
+
1
;
}
std
::
future
<
int32_t
>
SendCmd
(
uint32_t
table_id
,
int
cmd_id
,
const
std
::
vector
<
std
::
string
>
&
param
);
std
::
future
<
int32_t
>
SendSaveCmd
(
uint32_t
table_id
,
int
cmd_id
,
const
std
::
vector
<
std
::
string
>
&
param
);
bool
_running
=
false
;
bool
_flushing
=
false
;
std
::
atomic
<
uint32_t
>
_async_call_num
;
// 异步请求计数
// 异步push dense task
std
::
thread
_async_push_dense_thread
;
typedef
AsyncRequestTask
<
std
::
shared_ptr
<
std
::
vector
<
float
>>>
DenseAsyncTask
;
std
::
unordered_map
<
uint32_t
,
paddle
::
framework
::
Channel
<
DenseAsyncTask
*>>
_push_dense_task_queue_map
;
// 异步push sparse task
std
::
thread
_async_push_sparse_thread
;
typedef
AsyncRequestTask
<
std
::
shared_ptr
<
SparsePushTaskData
>>
SparseAsyncTask
;
std
::
unordered_map
<
uint32_t
,
paddle
::
framework
::
Channel
<
SparseAsyncTask
*>>
_push_sparse_task_queue_map
;
std
::
unordered_map
<
uint32_t
,
uint32_t
>
_push_sparse_merge_count_map
;
std
::
thread
_print_thread
;
int
PushSparseAsyncShardMerge
(
std
::
vector
<
std
::
shared_ptr
<
SparseAsyncTask
>>
&
task_list
,
// NOLINT
std
::
vector
<
int
>
&
request_kv_num
,
// NOLINT
int
table_id
,
int
shard_idx
,
ValueAccessor
*
accessor
);
int
PushSparseAsyncShardPush
(
std
::
vector
<
std
::
shared_ptr
<
SparseAsyncTask
>>
&
task_list
,
// NOLINT
std
::
vector
<
int
>
&
request_kv_num
,
// NOLINT
int
table_id
,
int
shard_idx
,
DownpourBrpcClosure
*
closure
,
ValueAccessor
*
accessor
);
SparseTaskPool
_sparse_task_pool
;
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_client_channels
;
// client2client
std
::
vector
<
std
::
array
<
std
::
shared_ptr
<
brpc
::
Channel
>
,
3
>>
_server_channels
;
// client2server
std
::
vector
<
std
::
array
<
std
::
shared_ptr
<
brpc
::
Channel
>
,
1
>>
_coordinator_channels
;
// client2coordinator
std
::
future
<
int32_t
>
PushDenseRawGradient
(
int
table_id
,
float
*
total_send_data
,
size_t
total_send_data_size
,
void
*
done
)
override
;
std
::
future
<
int32_t
>
PushSparseRawGradient
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
done
)
override
;
std
::
future
<
int32_t
>
PushSparseRawGradientPartial
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
uint32_t
num
,
void
*
done
,
int
pserver_idx
)
override
;
std
::
future
<
int32_t
>
PushSparseParam
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
,
void
*
done
)
override
;
std
::
future
<
int32_t
>
PushSparse
(
size_t
table_id
,
const
uint64_t
*
keys
,
const
float
**
update_values
,
size_t
num
)
override
;
void
PushSparseTaskConsume
();
private:
int32_t
StartClientService
();
void
PushDenseRawGradient
(
std
::
shared_ptr
<
DenseAsyncTask
>
&
task
,
// NOLINT
float
*
total_send_data
,
size_t
total_send_data_size
,
DownpourBrpcClosure
*
closure
);
float
_mae
=
0
;
float
_mse
=
0
;
uint16_t
_push_times
=
0
;
brpc
::
Server
_server
;
brpc
::
Server
_fl_server
;
DownpourPsClientService
_service
;
bool
_server_started
=
false
;
std
::
atomic_uint
grad_num_
{
0
};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_server.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
// Request timeout (ms) for pserver-to-pserver (server2server) RPC calls.
DEFINE_int32(pserver_timeout_ms_s2s,
             10000,
             "pserver request server timeout_ms");
// Connection-establishment timeout (ms) for pserver-to-pserver channels.
DEFINE_int32(pserver_connect_timeout_ms_s2s,
             10000,
             "pserver connect server timeout_ms");
// brpc connection type for pserver-to-pserver channels ("pooled" or "single").
DEFINE_string(pserver_connection_type_s2s,
              "pooled",
              "pserver connection_type[pooled:single]");
namespace
paddle
{
namespace
distributed
{
int32_t
BrpcPsServer
::
Initialize
()
{
auto
&
service_config
=
_config
.
downpour_server_param
().
service_param
();
if
(
!
service_config
.
has_service_class
())
{
LOG
(
ERROR
)
<<
"miss service_class in ServerServiceParameter"
;
return
-
1
;
}
auto
*
service
=
CREATE_PSCORE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
if
(
service
==
NULL
)
{
LOG
(
ERROR
)
<<
"service is unregistered, service_name:"
<<
service_config
.
service_class
();
return
-
1
;
}
_service
.
reset
(
service
);
if
(
service
->
Configure
(
this
)
!=
0
||
service
->
Initialize
()
!=
0
)
{
LOG
(
ERROR
)
<<
"service initialize failed, service_name:"
<<
service_config
.
service_class
();
return
-
1
;
}
if
(
_server
.
AddService
(
service
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
)
!=
0
)
{
LOG
(
ERROR
)
<<
"service add to brpc failed, service:"
<<
service_config
.
service_class
();
return
-
1
;
}
return
0
;
}
// Starts the brpc server on ip:port and BLOCKS until the server is told to
// stop (cv_ is signalled with stoped_ == true).
//
// Thread-pool size is max(trainer count, hardware concurrency). If the first
// Start attempt on the string endpoint fails, retries once with an
// integer-form endpoint (GetIntTypeEndpoint). On success the server registers
// itself with the PS environment before waiting.
//
// Returns this server's rank on normal shutdown; returns 0 when the brpc
// server could not be started at all (note: rank 0 is therefore ambiguous
// with the failure value — callers should check logs on 0).
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
  // Lock is held across the cv_ wait below; cv_.wait releases it while
  // blocked.
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
  brpc::ServerOptions options;
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  // Ensure at least one worker thread per trainer.
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";
    // Fallback: some environments only resolve the numeric-IP endpoint form.
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
  }
  // Publish this server's endpoint so clients / other pservers can find it.
  _environment->RegistePsServer(ip, port, _rank);
  // Block here until Stop flips stoped_ and notifies cv_.
  cv_.wait(lock, [&] { return stoped_; });
  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}
// Opens server-to-server (S2S) brpc channels from this pserver to every
// pserver registered in the environment (including, implicitly, itself).
// Used by cache-shuffle, which needs pservers to exchange data directly.
//
// Channel options come from the *_s2s gflags defined at the top of this file.
// A channel that fails to Init is logged but NOT treated as fatal: the
// function always returns 0, leaving the failed slot's channel unusable.
int32_t BrpcPsServer::StartS2S() {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
  options.connection_type = FLAGS_pserver_connection_type_s2s;
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
  options.max_retry = 3;
  std::vector<PSHost> pserver_list = _environment->GetPsServers();
  // One channel slot per known pserver, indexed by pserver id.
  _pserver_channels.resize(pserver_list.size());
  VLOG(2) << "pserver start s2s server_list size: "
          << _pserver_channels.size();
  std::ostringstream os;  // accumulates "ip:port," list for the success log
  std::string server_ip_port;
  for (size_t i = 0; i < pserver_list.size(); ++i) {
    server_ip_port.assign(pserver_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(pserver_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) !=
        0) {
      // Non-fatal: subsequent RPCs to this pserver will fail instead.
      LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
                 << " Failed!";
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "pserver connect success: " << os.str();
  return 0;
}
// Sends an asynchronous server-to-server message to another pserver over the
// S2S channels created by StartS2S().
//
// msg_type      user-defined message type; the peer responds with
//               msg_type + 1000 (checked in the closure below).
// to_pserver_id index into _pserver_channels of the destination pserver.
// msg           opaque payload, forwarded verbatim in the request's data field.
//
// Returns a future that resolves to the peer's response check result
// (0 = ok), or -1 immediately if to_pserver_id is out of range.
// NOTE: LOG(FATAL) aborts the process in glog; the set_value/return after it
// are effectively defensive dead code kept for non-aborting log configs.
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
    int msg_type, int to_pserver_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  if (static_cast<size_t>(to_pserver_id) >= _pserver_channels.size()) {
    LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
               << _pserver_channels.size();
    promise->set_value(-1);
    return fut;
  }
  // The closure deletes itself after firing and fulfills the promise with the
  // response check result; cmd_id 101 marks this as a pserver2pserver message
  // (see the cmd_id >= 100 branch in BrpcPsService::service).
  auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
    auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(101);
  closure->request(0)->set_client_id(_rank);
  closure->request(0)->set_table_id(0);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Handles an incoming server-to-server payload (the receive side of
// SendPServer2PServerMsg, dispatched from HandlePServer2PServerMsg).
//
// The payload is a BinaryArchive-serialized sequence of
// (uint64_t key, std::string value) pairs — the cache-shuffle record format.
// All pairs are deserialized and written to the _shuffled_ins channel.
// An empty message or an empty archive simply ends the S2S exchange (return 0).
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
                                         int pserver_id,
                                         const std::string &msg) {
  if (msg.length() == 0) {
    LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
    return 0;
  }
  paddle::framework::BinaryArchive ar;
  // const_cast is required by SetReadBuffer's non-const signature; the buffer
  // is only read.
  ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
    return 0;
  }
  std::vector<std::pair<uint64_t, std::string>> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
  }
  // Sanity check: deserialization consumed the archive exactly.
  CHECK(ar.Cursor() == ar.Finish());
  this->_shuffled_ins->Write(std::move(data));
  return 0;
}
// Returns the TCP port the embedded brpc server is actually listening on
// (useful when the server was started on an OS-assigned port).
int32_t BrpcPsServer::Port() {
  const auto &listen_addr = _server.listen_address();
  return listen_addr.port;
}
// Builds the cmd_id -> member-function dispatch table used by service(),
// registers the cost profilers for the four hot RPC paths, and initializes
// per-table shard info.
//
// Always returns 0.
int32_t BrpcPsService::Initialize() {
  _is_initialize_shard_info = false;
  // Server lifecycle / table maintenance handlers.
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  // Dense / sparse pull & push data-path handlers.
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  // Checkpointing handlers.
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  // Parameter initialization / geo-SGD / control handlers.
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] =
      &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
  // for save cache
  _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
      &BrpcPsService::SaveCacheTable;
  _service_handler_map[PS_GET_CACHE_THRESHOLD] =
      &BrpcPsService::GetCacheThreshold;
  _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
  _service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
  _service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
      &BrpcPsService::CheckSavePrePatchDone;
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
  // Shard initialization: the shard info of server_list is only available
  // from the environment after the server has started.
  InitializeShardInfo();
  return 0;
}
// Guard used at the top of every table-bound handler: if the request named a
// table id that does not exist, records the error on the response and makes
// the enclosing handler return -1. Must be used only inside functions whose
// return type accepts -1.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Propagates (rank, shard_num) to every table exactly once.
//
// shard_num is the number of pservers known to the environment, which is only
// available after the server has started. Uses double-checked locking on
// _is_initialize_shard_info so concurrent callers initialize at most once.
//
// Always returns 0.
int32_t BrpcPsService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    // Re-check under the lock: another thread may have won the race.
    if (_is_initialize_shard_info) {
      return 0;
    }
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Iterate by reference: the original copied each map element (including
    // its smart-pointer value) per iteration for no benefit.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, shard_num);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// brpc entry point for all PS requests. Validates the request, then
// dispatches by cmd_id:
//   cmd_id <  100 : regular client command — looked up in _service_handler_map
//                   and invoked as a member function on this service.
//   cmd_id >= 100 : server-to-server message — forwarded to
//                   HandlePServer2PServerMsg on the owning server.
// Errors are reported through the response's err_code/err_msg; the function
// itself never throws. done is released by the ClosureGuard on every path.
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  // RAII: runs done->Run() on all return paths, completing the RPC.
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
    return;
  }
  // Default to success; handlers overwrite on failure.
  response->set_err_code(0);
  response->set_err_msg("");
  // May be null for an unknown table id; each handler guards with
  // CHECK_TABLE_EXIST.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  if (request->cmd_id() < 100) {
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      set_response_code(*response, -1, err_msg.c_str());
      return;
    }
    serviceHandlerFunc handler_func = itr->second;
    // Pointer-to-member-function dispatch on this service instance.
    int service_ret = (this->*handler_func)(table, *request, *response, cntl);
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  } else {
    // Server-to-server path (e.g. cache-shuffle data exchange, cmd_id 101).
    int service_ret = _server->HandlePServer2PServerMsg(
        request->cmd_id(), request->client_id(), request->data());
    if (service_ret != 0) {
      response->set_err_code(-1);
      response->set_err_msg("handle_pserver2pserver_msg failed");
    }
  }
}
// Serves a dense-table pull: reads the number of dense values from params(0)
// and streams the pulled floats back in the response attachment.
//
// params(0) holds a raw little-endian uint32 (the element count "num"),
// reinterpreted directly from the string bytes. The result buffer is sized as
// num * (accessor select_size in bytes) / sizeof(float) and recycled through
// butil's object pool.
// Returns 0 even on argument error (the error is in the response code).
int32_t BrpcPsService::PullDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1 for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");
  // params(0) carries the binary uint32 element count, not ASCII digits.
  uint32_t num = *(const uint32_t *)request.params(0).c_str();
  // Pooled vector to avoid a fresh allocation per request.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num *
                   table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);
  // Zero-copy-ish reply: raw float bytes in the brpc attachment.
  cntl->response_attachment().append(
      reinterpret_cast<char *>(res_data->data()),
      res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// Initializes a dense table's parameters from the request attachment.
//
// Attachment layout: |--num (4B uint32)--|--num*... float values--|.
// The attachment is fetched into a thread_local scratch buffer; note that
// brpc's fetch() may return a pointer to its internal contiguous block
// instead of copying, so all parsing goes through the returned `data`
// pointer, never through push_buffer directly.
//
// Fix vs. original: the buffer was only reserve()d (size() stayed 0), so a
// copying fetch() wrote past the string's size — undefined behavior. It is
// now resize()d to the attachment size, making the destination storage valid.
//
// Returns 0; errors are reported via set_response_code.
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  // thread_local scratch: reused across requests on the same worker thread.
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  // resize (not reserve): fetch() may memcpy into this storage, which is
  // only legal up to size().
  push_buffer.resize(req_buffer_size);
  const char *data = (const char *)cntl->request_attachment().fetch(
      const_cast<char *>(push_buffer.data()), req_buffer_size);
  /*
  Push Content:
  |--num--|---valuesData---|
  |--4B---|----------------|
  */
  uint32_t num = *(const uint32_t *)data;
  const float *values = (const float *)(data + sizeof(uint32_t));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  // is_param=true: this is a parameter initialization, not a gradient push.
  table_context.push_context.is_param = true;
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}
// Applies a dense gradient push carried in the request's data field.
//
// data layout: |--num (4B uint32)--|--float values--|.
// An empty payload is silently accepted (returns 0) — this is deliberate
// best-effort behavior, see the commented-out response code below.
int32_t BrpcPsService::PushDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    // set_response_code(response, 0, "push dense data is empty");
    return 0;
  }
  CostTimer timer("pserver_server_push_dense");
  /*
  Push Content:
  |--num--|---valuesData---|
  |--4B---|----------------|
  */
  // Leading 4 bytes are the binary element count.
  uint32_t num = *(const uint32_t *)(request.data().data());
  TableContext table_context;
  table_context.value_type = Dense;
  // Values start immediately after the count header.
  table_context.push_context.values =
      (const float *)(request.data().data() + sizeof(uint32_t));
  table_context.num = num;
  // const float *values = (const float *)(request.data().data() +
  // sizeof(uint32_t));
  if (table->Push(table_context) != 0) {
    // if (table->PushDense(values, num) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }
  return 0;
}
// Records that the requesting trainer reached a barrier on this table.
// params(0) names the barrier type; client_id identifies the trainer.
// Always returns 0 — a missing barrier type is reported via the response code.
int32_t BrpcPsService::Barrier(Table *table,
                               const PsRequestMessage &request,
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() >= 1) {
    const auto requesting_trainer = request.client_id();
    const auto barrier_kind = request.params(0);
    table->Barrier(requesting_trainer, barrier_kind);
    return 0;
  }
  set_response_code(response,
                    -1,
                    "PsRequestMessage.params is requeired at "
                    "least 1 for num of sparse_key");
  return 0;
}
// Initializes sparse-table parameters for a batch of keys.
//
// params(0) holds the binary uint32 key count; request.data() holds the wire
// payload: num uint64 keys followed by the float values. Empty payloads are
// silently accepted.
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  // Binary uint32 key count reinterpreted from the param bytes.
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  const uint64_t *keys = (const uint64_t *)push_data.data();
  // Values start right after the num 8-byte keys.
  const float *values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  // is_param=true distinguishes initialization from a gradient push.
  table_context.push_context.is_param = true;
  table_context.num = num;
  // if (table->PushSparseParam(keys, values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}
// Geo-SGD pull: returns the keys/values the table has accumulated for the
// requesting trainer since its last pull.
//
// Response attachment layout: |num (4B)|num uint64 ids|matching float values|.
// The table fills `ids` and `values` through the geo_pull_* pointers in the
// TableContext.
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->pull_geo_param",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  // NOTE(review): this scratch buffer is declared but unused here —
  // presumably a leftover; kept for byte-compatibility.
  thread_local std::string push_sparse_request_buffer;
  auto trainer_id = request.client_id();
  std::vector<float> values;
  std::vector<uint64_t> ids;
  TableContext table_context;
  table_context.value_type = Sparse;
  // Output parameters: the table writes the pending geo updates into these.
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);
  // table->PullGeoParam(trainer_id, &values, &ids);
  uint32_t num = ids.size();
  cntl->response_attachment().append(reinterpret_cast<char *>(&num),
                                     sizeof(uint32_t));
  cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
                                     ids.size() * sizeof(uint64_t));
  cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
                                     values.size() * sizeof(float));
  return 0;
}
// Serves a sparse-table pull for a batch of keys.
//
// params(0) holds the binary uint32 key count; the request attachment holds
// the serialized PullSparseValue (keys and frequency info). The pulled
// embeddings (num * select_dim floats) are streamed back in the response
// attachment.
//
// Fix vs. original: the thread_local scratch buffer was only reserve()d
// (size() stayed 0) before being handed to fetch() as a copy destination —
// a write past the string's size is undefined behavior. It is now resize()d.
// All parsing goes through the pointer returned by fetch(), which may be
// brpc's internal block rather than our buffer.
//
// Returns 0; errors are reported via set_response_code.
int32_t BrpcPsService::PullSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_pull_sparse");
  // Binary uint32 key count reinterpreted from the param bytes.
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
  // thread_local scratch reused across requests on this worker thread;
  // resize (not reserve) so a copying fetch() writes into valid storage.
  thread_local std::string req_buffer;
  req_buffer.resize(req_buffer_size);
  const void *data = cntl->request_attachment().fetch(
      const_cast<char *>(req_buffer.data()), req_buffer_size);
  auto value = PullSparseValue(num, dim);
  value.DeserializeFromBytes(const_cast<void *>(data));
  // Pooled result vector: num rows of dim floats.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * dim);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);
  // table->PullSparse(res_data->data(), value);
  cntl->response_attachment().append(
      reinterpret_cast<char *>(res_data->data()),
      res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// Applies a sparse gradient push for a batch of keys.
//
// params(0) holds the binary uint32 key count; request.data() holds the wire
// payload: num uint64 keys followed by the float gradient values. Empty
// payloads are silently accepted.
int32_t BrpcPsService::PushSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");
  // Binary uint32 key count reinterpreted from the param bytes.
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = (const uint64_t *)push_data.data();
  // Values start right after the num 8-byte keys.
  table_context.push_context.values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  table_context.num = num;
  // const uint64_t *keys = (const uint64_t *)push_data.data();
  // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
  // num);
  if (table->Push(table_context) != 0) {
    // if (table->PushSparse(keys, values, num) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}
// Reports a table's statistics pair (as produced by Table::PrintTableStat)
// back to the client, BinaryArchive-serialized into the response data field.
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const std::pair<int64_t, int64_t> stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first;
  archive << stat.second;
  std::string serialized_stat(archive.Buffer(), archive.Length());
  response.set_data(serialized_stat);
  return 0;
}
// Loads a single table from disk. params(0) is the checkpoint path and
// params(1) the load parameter string, both forwarded to Table::Load.
// Returns 0 on success, -1 on bad arguments or load failure (with the reason
// recorded on the response).
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  const std::string &load_path = request.params(0);
  const std::string &load_param = request.params(1);
  const bool load_ok = table->Load(load_path, load_param) == 0;
  if (!load_ok) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
// Loads every registered table with the same path/load_param arguments,
// delegating to LoadOneTable. Stops and returns -1 at the first failure;
// the `table` parameter from the dispatcher is ignored.
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &registry = *(_server->GetTable());
  for (auto &entry : registry) {
    const int32_t load_rc =
        LoadOneTable(entry.second.get(), request, response, cntl);
    if (load_rc != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
// Saves a single table. params(0) is the destination path, params(1) the
// save mode, both forwarded to Table::Save after flushing pending updates.
//
// NOTE: unlike most handlers this returns the (non-negative) feasign count
// reported by Table::Save on success — callers such as SaveAllTable rely on
// that; -1 signals failure.
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  // Drain buffered updates so the checkpoint is consistent.
  table->Flush();
  int32_t feasign_size = 0;
  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  feasign_size = table->Save(request.params(0), request.params(1));
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
// Saves every registered table via SaveOneTable with the same path/mode.
// A negative result from any table aborts with -1; on success returns 0
// (individual feasign counts are not aggregated). The dispatcher-provided
// `table` parameter is ignored.
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &registry = *(_server->GetTable());
  int32_t feasign_size = 0;
  for (auto &entry : registry) {
    feasign_size = SaveOneTable(entry.second.get(), request, response, cntl);
    if (feasign_size < 0) {
      LOG(ERROR) << "save table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
// Saves the table's cache portion (hot features selected by a prior cache
// shuffle). params(0) is the path, params(1) the mode; the shuffled records
// are consumed from the server's _shuffled_ins channel.
// Returns the feasign count from SaveCache on success, -1 on failure.
int32_t BrpcPsService::SaveCacheTable(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "PsRequestMessage.datas is requeired at least 3, path&mode");
    return -1;
  }
  // Drain buffered updates before writing the cache checkpoint.
  table->Flush();
  int32_t feasign_size = 0;
  // if (_server->_shuffled_ins->size() <= 0) {
  //   LOG(WARNING) << "shuffled ins size <= 0";
  //}
  feasign_size = table->SaveCache(
      request.params(0), request.params(1), _server->_shuffled_ins);
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
// Starts a cache shuffle: redistributes hot features between pservers so each
// can later SaveCacheTable its own shard.
//
// params layout: 0=path, 1=mode, 2=cache_threshold (decimal string),
// 3..n = optional extra table ids to shuffle together with this one.
// Opens the server-to-server channels (StartS2S) and hands the table a
// callback that routes records to peer pservers via SendPServer2PServerMsg.
int32_t BrpcPsService::CacheShuffle(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // start cache shuffle
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.datas is requeired at least 3, "
                      "path&mode&cache_threshold");
    return -1;
  }
  table->Flush();
  double cache_threshold = std::stod(request.params(2));
  LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
  // auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
  // std::string>>();
  // shuffled_ins->set_block_size(80000);
  // Lazily open pserver-to-pserver channels for the data exchange.
  _server->StartS2S();
  // Callback the table uses to ship shuffled records to a peer pserver.
  // Captures `this` only; safe because the service outlives the shuffle.
  std::function<std::future<int32_t>(
      int msg_type, int to_pserver_id, const std::string &msg)>
      send_msg_func = [this](int msg_type,
                             int to_pserver_id,
                             const std::string &msg) -> std::future<int32_t> {
    return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };
  // Optional additional tables (params 3..n are decimal table ids).
  std::vector<Table *> table_ptrs;
  for (int i = 3; i < request.params_size(); ++i) {
    int table_id = std::stoi(request.params(i));
    Table *table_ptr = _server->GetTable(table_id);
    table_ptrs.push_back(table_ptr);
  }
  // Default to shuffling just the addressed table.
  if (table_ptrs.empty()) {
    table_ptrs.push_back(table);
  }
  table->CacheShuffle(request.params(0),
                      request.params(1),
                      cache_threshold,
                      send_msg_func,
                      _server->_shuffled_ins,
                      table_ptrs);
  return 0;
}
// Reports the table's current cache threshold to the client.
// The value is rendered with 15 significant digits and returned through
// response.data(); always returns 0 (a negative threshold is only logged).
int32_t BrpcPsService::GetCacheThreshold(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  const double threshold = table->GetCacheThreshold();
  if (threshold < 0) {
    LOG(WARNING) << "wrong threshold: " << threshold;
  }
  // Serialize with enough precision for a double round-trip.
  std::stringstream formatter;
  formatter << std::setprecision(15) << threshold;
  response.set_data(formatter.str());
  return 0;
}
// Reverts every table on this server to its last consistent state.
// The `table` argument is unused; all tables are flushed then reverted.
// Always returns 0.
int32_t BrpcPsService::Revert(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    auto &tbl = entry.second;
    tbl->Flush();
    tbl->Revert();
  }
  return 0;
}
// Blocks until every table's pre-patch save has completed.
// The `table` argument is unused; the check runs over the whole table map.
// Always returns 0.
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
                                             const PsRequestMessage &request,
                                             PsResponseMessage &response,
                                             brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    entry.second->CheckSavePrePatchDone();
  }
  return 0;
}
// Shrinks `table` using the threshold carried in request.params(0).
// Returns 0 on success, -1 on a malformed request or a failed shrink.
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    // (message text preserved verbatim, including the historical typo)
    set_response_code(
        response, -1, "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }
  table->Flush();
  const int32_t shrink_ret = table->Shrink(request.params(0));
  if (shrink_ret != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }
  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}
// Empties a single table: flush pending updates, then clear all entries.
// Always returns 0 once the table is known to exist.
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  table->Clear();
  return 0;
}
// Clears every table on the server via ClearOneTable.
// The `table` argument is unused. Returns 0 on success, -1 on the first
// table that fails to clear.
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    const int32_t ret =
        ClearOneTable(entry.second.get(), request, response, cntl);
    if (ret != 0) {
      return -1;
    }
  }
  return 0;
}
// Asynchronously stops the server. The stop runs on a detached thread so
// this RPC can reply to the client before the brpc server shuts down
// (stopping synchronously would tear down the channel this response uses).
// NOTE(review): the detached thread captures the raw _server pointer; this
// assumes the server object outlives the stop call — confirm shutdown order.
int32_t BrpcPsService::StopServer(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  auto *p_server = _server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    VLOG(3) << "Server Stoped";
  });
  t_stop.detach();
  return 0;
}
// Stops the platform profiler and dumps results to a per-rank file
// ("server_<rank>_profile"). Counterpart of StartProfiler.
// NOTE(review): the format uses "%s" with _rank — Paddle's string::Sprintf
// appears to accept this, but verify the emitted filename.
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// Enables CPU-state profiling on this server process; results are written
// when StopProfiler is called. Always returns 0.
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
// Pushes global-step counters into `table`.
// The request payload is interpreted as a uint32_t header followed by
// int64_t step values; only a pointer into the request buffer is handed to
// the table (no copy), so the buffer must stay alive for the Push call.
// NOTE(review): the sizeof(uint32_t) skip presumably strips a length/count
// prefix written by the client — confirm against the sender.
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response);
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    // Empty payload is treated as a no-op success (code 0).
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  // Skip the leading uint32_t and view the rest as int64_t steps.
  const int64_t *values =
      (const int64_t *)(request.data().data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();
  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;
  // if (table->PushDense(values, trainer_id) != 0) {
  if (table->Push(context) != 0) {
    set_response_code(response, -1, "run_program failed");
  }
  return 0;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_ps_server.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace
brpc
{
class
Controller
;
}
// namespace brpc
namespace
google
{
namespace
protobuf
{
class
Closure
;
class
RpcController
;
}
// namespace protobuf
}
// namespace google
namespace
paddle
{
namespace
distributed
{
class PsRequestMessage;
class PsResponseMessage;
class Table;

// brpc-backed parameter-server implementation of PSServer.
// Owns the brpc::Server instance, the request-handling service, and the
// channels used for pserver-to-pserver (S2S) communication.
class BrpcPsServer : public PSServer {
 public:
  BrpcPsServer() {}
  virtual ~BrpcPsServer() {}
  // Binds and starts the brpc server on ip:port; returns the host signature.
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Stops the brpc server. Wakes any thread blocked on cv_ (see stoped_),
  // gives in-flight RPCs up to 1000 ms to finish, then joins.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    stoped_ = true;
    cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  // Port the server is listening on.
  int32_t Port();
  // Opens server-to-server channels (used e.g. by cache shuffle).
  int32_t StartS2S() override;
  // Sends an async message to another pserver; the future resolves with the
  // peer's status code.
  ::std::future<int32_t> SendPServer2PServerMsg(int msg_type,
                                                int to_pserver_id,
                                                const std::string &msg) override;
  // Handles a message received from a peer pserver.
  int32_t ReceiveFromPServer(int msg_type,
                             int pserver_id,
                             const std::string &msg) override;

 private:
  virtual int32_t Initialize();
  // mutex_/cv_/stoped_ implement the shutdown handshake used by Stop().
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer pserver for S2S messaging.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class BrpcPsService;

// Pointer-to-member type for BrpcPsService request handlers; used by the
// cmd-id -> handler dispatch maps below. Every handler shares this
// signature.
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,  // NOLINT
    brpc::Controller *cntl);
// The brpc service that dispatches parameter-server RPCs.
// service() decodes the command id from the request and routes to one of
// the private handlers below via _service_handler_map/_msg_handler_map.
// All handlers share the serviceHandlerFunc signature: they operate on the
// addressed Table and report status through `response`.
class BrpcPsService : public PsBaseService {
 public:
  int32_t Initialize() override;
  // brpc entry point: dispatches `request` to the matching handler and
  // invokes `done` when the response is ready.
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;

 private:
  int32_t InitializeShardInfo();
  // --- dense/sparse parameter traffic ---
  int32_t PullDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,  // NOLINT
                    brpc::Controller *cntl);
  int32_t PushDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,  // NOLINT
                    brpc::Controller *cntl);
  int32_t PushDenseParam(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t PushSparseParam(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,  // NOLINT
                          brpc::Controller *cntl);
  int32_t PullSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  int32_t PullGeoParam(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,  // NOLINT
                  brpc::Controller *cntl);
  int32_t PushSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  // --- table lifecycle: load / save / shrink / clear ---
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t ShrinkTable(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,  // NOLINT
                      brpc::Controller *cntl);
  int32_t ClearOneTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  int32_t ClearAllTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  // --- server control & diagnostics ---
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t PushGlobalStep(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  // --- cache management ---
  int32_t CacheShuffle(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveCacheTable(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t GetCacheThreshold(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,  // NOLINT
                            brpc::Controller *cntl);
  int32_t Revert(Table *table,
                 const PsRequestMessage &request,
                 PsResponseMessage &response,  // NOLINT
                 brpc::Controller *cntl);
  int32_t CheckSavePrePatchDone(Table *table,
                                const PsRequestMessage &request,
                                PsResponseMessage &response,  // NOLINT
                                brpc::Controller *cntl);

  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  // Dispatch tables: request cmd id -> member-function handler.
  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
};
// Closure that fans in `num` parallel pserver RPCs: Run() is invoked once
// per finished sub-request, and only the last arrival fires the callback.
// The object deletes itself after the callback, so it must be heap-allocated
// and never touched after the final Run().
class DownpourPServerBrpcClosure : public PServerClosure {
 public:
  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
      : PServerClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourPServerBrpcClosure() {}
  // Called by brpc when one sub-request completes. fetch_sub returns the
  // pre-decrement value, so exactly the last completer (value 1) runs the
  // callback and frees this closure.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Per-sub-request accessors; `i` must be < the `num` given at construction.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // Placeholder checks: currently always report success.
  int check_response(size_t request_idx, int cmd_id) { return 1; }
  int check_save_response(size_t request_idx, int cmd_id) { return 1; }

 private:
  // Count of sub-requests still in flight (atomic: Run() races across
  // brpc worker threads).
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_utils.cc
0 → 100644
View file @
de2e6515
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include <arpa/inet.h>
#include <netdb.h>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
distributed
{
// Maps a wire-format VariableMessage::Type to the framework's proto VarType.
// Throws InvalidArgument for any type without a defined mapping.
framework::proto::VarType::Type VarMessageToVarType(
    VariableMessage::Type type) {
  switch (type) {
    case VariableMessage::FP32:
      return framework::proto::VarType::FP32;  // NOLINT
    case VariableMessage::FP64:
      return framework::proto::VarType::FP64;  // NOLINT
    case VariableMessage::INT32:
      return framework::proto::VarType::INT32;  // NOLINT
    case VariableMessage::INT64:
      return framework::proto::VarType::INT64;  // NOLINT
    case VariableMessage::BOOL:
      return framework::proto::VarType::BOOL;  // NOLINT
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "VarMessageToVarType:Unsupported type %d", type));
  }
}
// Packs a set of variables into a MultiVarMsg plus a raw IOBuf.
// Metadata (names, dims, lod) goes into `request`; tensor bytes are
// appended to `iobuf` in the same order as send_var_name_val, so sender
// and receiver must agree on that ordering.
// Variables that are neither LoDTensor nor SelectedRows are silently
// skipped (an empty segment is appended for them).
void SerializeToMultiVarMsgAndIOBuf(
    const std::string &message_name,
    const std::vector<std::string> &send_var_name_val,
    const std::vector<std::string> &recv_var_name_val,
    const platform::DeviceContext &ctx,
    const framework::Scope *scope,
    MultiVarMsg *request,
    butil::IOBuf *iobuf) {
  // 1. message_name
  request->set_message_name(message_name);
  // 2. var_names
  for (auto &send_var_name : send_var_name_val) {
    request->add_send_var_names(send_var_name);
  }
  for (auto &recv_var_name : recv_var_name_val) {
    request->add_recv_var_names(recv_var_name);
  }
  // 3. VarMessage
  for (auto &send_var_name : send_var_name_val) {
    auto *send_var_msg = request->add_var_messages();
    butil::IOBuf temp_iobuf;
    send_var_msg->set_varname(send_var_name);
    // NOTE(review): FindVar may return nullptr for an unknown name; the
    // var->IsType call below assumes the variable exists — confirm callers
    // only pass names present in `scope`.
    framework::Variable *var = scope->FindVar(send_var_name);
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    iobuf->append(temp_iobuf);
  }
}
// Serializes one LoDTensor: type/lod/dtype/dims go into `var_msg`, and the
// payload is appended to `iobuf` as an 8-byte length followed by the raw
// tensor bytes. GPU tensors are staged through a temporary host buffer.
void SerializeLodTensor(framework::Variable *var,
                        const platform::DeviceContext &ctx,
                        VarMsg *var_msg,
                        butil::IOBuf *iobuf) {
  auto *tensor = var->GetMutable<framework::LoDTensor>();
  var_msg->set_type(::paddle::distributed::LOD_TENSOR);
  const framework::LoD lod = tensor->lod();
  // lod_level is only set when LoD information exists.
  if (lod.size() > 0) {
    var_msg->set_lod_level(lod.size());
    for (auto &each : lod) {
      VarMsg::LodData *lod_inner = var_msg->add_lod();
      for (auto &d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto &dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // Fixed 8-byte length prefix; the deserializer reads it as
    // `unsigned long` — the wire format assumes a 64-bit platform.
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device memory cannot be appended directly: copy to host first.
    char *temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream = reinterpret_cast<const phi::GPUContext &>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    // NOTE(review): the async copy on `stream` is not synchronized before
    // temp_ptr is read below — presumably Copy with a host destination
    // blocks, but confirm.
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Serializes one SelectedRows variable: row indices are copied into the
// message's data field, value-tensor metadata into `var_msg`, and the value
// payload is appended to `iobuf` as an 8-byte length + raw bytes. GPU
// tensors are staged through a temporary host buffer.
void SerializeSelectedRows(framework::Variable *var,
                           const platform::DeviceContext &ctx,
                           VarMsg *var_msg,
                           butil::IOBuf *iobuf) {
  phi::SelectedRows *slr = var->GetMutable<phi::SelectedRows>();
  auto *tensor = slr->mutable_value();
  auto *rows = slr->mutable_rows();
  var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
  var_msg->set_slr_height(slr->height());
  auto *var_data = var_msg->mutable_data();
  var_data->clear();
  var_data->resize(rows->size() * sizeof(int64_t));
  char *data_ptr = const_cast<char *>(var_data->data());
  // Fix: the old code took &((*rows)[0]), which is undefined behavior when
  // the row list is empty; rows->data() + an empty guard is always valid.
  if (!rows->empty()) {
    memcpy(data_ptr, rows->data(), rows->size() * sizeof(int64_t));
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto &dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // 8-byte length prefix, matched by the deserializer's `unsigned long`.
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device memory cannot be appended directly: copy to host first.
    char *temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream = reinterpret_cast<const phi::GPUContext &>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char *>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char *>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Server-side deserialization: creates (or reuses) each variable in `scope`
// via scope->Var and fills it from the message + IOBuf payload.
// NOTE(review): the loop bound is send_var_names_size() while the body
// indexes var_messages() — this assumes the two repeated fields are always
// the same length; confirm against the serializer.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        framework::Scope *scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0;
       recv_var_index < multi_msg.send_var_names_size();
       ++recv_var_index) {
    const auto &msg = multi_msg.var_messages(recv_var_index);
    auto *var = scope->Var(msg.varname());
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Client-side deserialization: unlike the server overload, variables must
// already exist in `scope` (FindVar + enforce), since a const Scope cannot
// create them.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        const framework::Scope *scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0;
       recv_var_index < multi_msg.send_var_names_size();
       ++recv_var_index) {
    const auto &msg = multi_msg.var_messages(recv_var_index);
    auto *var = scope->FindVar(msg.varname());
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable %s in scope.", msg.varname()));
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Rebuilds a LoDTensor from its VarMsg metadata and the next segment of the
// IOBuf (8-byte length prefix + raw bytes). Advances io_buffer_itr past the
// consumed segment. GPU targets stage the bytes through a host buffer.
void DeserializeLodTensor(framework::Variable *var,
                          const VarMsg &msg,
                          butil::IOBufBytesIterator &io_buffer_itr,  // NOLINT
                          const platform::DeviceContext &ctx) {
  const auto place = ctx.GetPlace();
  framework::LoDTensor *tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int> vec_dim;
  for (auto &x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  // Restore LoD levels exactly as serialized.
  framework::LoD lod;
  for (int i = 0; i < msg.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
      v.push_back(msg.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);
  void *tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    unsigned long data_len;  // NOLINT
    // Read the 8-byte length prefix written by the serializer.
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    unsigned long data_len;  // NOLINT
    // Host staging buffer sized from the tensor's own dims/dtype.
    // NOTE(review): data_len from the wire is trusted to equal this size;
    // a mismatched message would over/under-read — confirm sender always
    // agrees with msg.dims().
    char *temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward((void *)temp_ptr, data_len);  // NOLINT
    auto stream = reinterpret_cast<const phi::GPUContext &>(ctx).stream();
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 (void *)temp_ptr,  // NOLINT
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
// Rebuilds a SelectedRows variable: row indices come from msg.data()
// (dims()[0] int64 entries), the value tensor from the next IOBuf segment
// (8-byte length prefix + raw bytes). GPU targets stage through a host
// buffer.
void DeserializeSelectedRows(
    framework::Variable *var,
    const VarMsg &msg,
    butil::IOBufBytesIterator &io_buffer_itr,  // NOLINT
    const platform::DeviceContext &ctx) {
  const auto place = ctx.GetPlace();
  auto *slr = var->GetMutable<phi::SelectedRows>();
  framework::Tensor *tensor = slr->mutable_value();
  slr->set_height(msg.slr_height());
  // dims()[0] is the row count; row ids travel inside msg.data().
  std::vector<int64_t> tmp_rows(msg.dims()[0]);
  memcpy(tmp_rows.data(),
         msg.data().data(),
         msg.dims()[0] * sizeof(int64_t));
  slr->set_rows(tmp_rows);
  std::vector<int> vec_dim;
  for (auto &x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  void *tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    unsigned long data_len;  // NOLINT
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    char *temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    unsigned long data_len;  // NOLINT
    io_buffer_itr.copy_and_forward((void *)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(temp_ptr, data_len);
    auto stream = reinterpret_cast<const phi::GPUContext &>(ctx).stream();
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 temp_ptr,
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
std
::
string
GetIntTypeEndpoint
(
const
std
::
string
&
ip
,
const
uint32_t
&
port
)
{
// There are usually two forms of IP address: ip(int) / ip (hostname)
// If there're some problem with DNS, or ip triggers the bug of Brpc
// We will try to get the IP address of the domain name manually again
std
::
string
ip_port
=
ip
+
":"
+
std
::
to_string
(
port
);
struct
hostent
*
hp
=
NULL
;
hp
=
gethostbyname
(
ip
.
c_str
());
if
(
NULL
==
hp
)
{
LOG
(
ERROR
)
<<
"Brpc Start failed, ip_port= "
<<
ip_port
<<
" , Error infomation: "
<<
hstrerror
(
h_errno
);
}
int
i
=
0
;
char
*
int_ip
=
NULL
;
while
(
hp
->
h_addr_list
[
i
]
!=
NULL
)
{
int_ip
=
inet_ntoa
(
*
(
struct
in_addr
*
)
hp
->
h_addr_list
[
i
]);
VLOG
(
3
)
<<
"Brpc Get host by name, host:"
<<
ip
<<
" -> ip: "
<<
int_ip
;
break
;
}
std
::
string
str_ip
=
int_ip
;
std
::
string
int_ip_port
=
str_ip
+
":"
+
std
::
to_string
(
port
);
return
int_ip_port
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/ps/service/brpc_utils.h
0 → 100644
View file @
de2e6515
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <netdb.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/phi/backends/dynload/port.h"
// Forward declarations — only pointers/references to these types appear in
// the prototypes below.
namespace butil {
class IOBuf;
class IOBufBytesIterator;
}  // namespace butil
namespace grpc {
class ByteBuffer;
}  // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace distributed {
// Wire-format aliases for the sendrecv.proto messages.
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;

// Packs variables from `scope` into `var_msg` (metadata) + `iobuf` (raw
// tensor bytes, ordered as send_var_name_val).
void SerializeToMultiVarMsgAndIOBuf(
    const std::string &message_name,
    const std::vector<std::string> &send_var_name_val,
    const std::vector<std::string> &recv_var_name_val,
    const platform::DeviceContext &ctx,
    const framework::Scope *scope,
    MultiVarMsg *var_msg,
    butil::IOBuf *iobuf);
// Serializes one LoDTensor variable (metadata -> var_msg, bytes -> iobuf).
void SerializeLodTensor(framework::Variable *var,
                        const platform::DeviceContext &ctx,
                        VarMsg *var_msg,
                        butil::IOBuf *iobuf);
// Serializes one SelectedRows variable (rows + metadata -> request,
// value bytes -> iobuf).
void SerializeSelectedRows(framework::Variable *var,
                           const platform::DeviceContext &ctx,
                           VarMsg *request,
                           butil::IOBuf *iobuf);
// Deserialize for Server
// (mutable Scope: missing variables are created).
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        framework::Scope *scope);
// Deserialize for Client
// (const Scope: variables must already exist).
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg &multi_msg,
                                        const butil::IOBuf *iobuf,
                                        const platform::DeviceContext &ctx,
                                        const framework::Scope *scope);
// Rebuilds a LoDTensor from msg + the next iobuf segment; advances `iobuf`.
void DeserializeLodTensor(framework::Variable *var,
                          const VarMsg &msg,
                          butil::IOBufBytesIterator &iobuf,  // NOLINT
                          const platform::DeviceContext &ctx);
// Rebuilds a SelectedRows from msg + the next iobuf segment; advances
// `iobuf`.
void DeserializeSelectedRows(framework::Variable *var,
                             const VarMsg &msg,
                             butil::IOBufBytesIterator &iobuf,  // NOLINT
                             const platform::DeviceContext &ctx);
// Resolves a hostname endpoint to a numeric "ip:port" string.
std::string GetIntTypeEndpoint(const std::string &ip, const uint32_t &port);
}  // namespace distributed
}  // namespace paddle
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment