Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/boxing/slice_boxing_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/boxing/eager_boxing_logger.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
namespace oneflow {
namespace private_details {
Maybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
const std::string& log_prefix) {
const auto& tensor_placement = JUST(tensor->parallel_desc());
if (tensor_placement->device_type() == DeviceType::kCPU
|| tensor_placement->device_type() == DeviceType::kCUDA) {
return tensor;
}
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU));
const auto& boxing_interpreter =
JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
tensor_nd_sbp, tensor_nd_sbp, tensor_placement, new_placement, *tensor->shape()));
Singleton<const EagerBoxingLogger>::Get()->Log(
*JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix);
return JUST(boxing_interpreter->Interpret(tensor, tensor_nd_sbp, tensor_nd_sbp, tensor_placement,
new_placement));
}
Maybe<one::Tensor> PostprocessOutputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
Symbol<PlacedNdSbp> placed_nd_sbp,
const std::string& log_prefix) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_nd_sbp == placed_nd_sbp->nd_sbp())
<< Error::RuntimeError()
<< "Compute slice boxing failed. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible";
CHECK_OR_RETURN(tensor_placement->EqualsIgnoringDeviceType(*placed_nd_sbp->placement()))
<< Error::RuntimeError()
<< "Compute slice boxing failed. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible";
if (JUST(tensor->parallel_desc()) == placed_nd_sbp->placement()) { return tensor; }
const auto& boxing_interpreter =
JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
placed_nd_sbp->nd_sbp(), placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()),
placed_nd_sbp->placement(), *tensor->shape()));
Singleton<const EagerBoxingLogger>::Get()->Log(
*JUST(boxing_interpreter->boxing_interpreter_status()), log_prefix);
return JUST(boxing_interpreter->Interpret(tensor, placed_nd_sbp->nd_sbp(),
placed_nd_sbp->nd_sbp(), JUST(tensor->parallel_desc()),
placed_nd_sbp->placement()));
}
const std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type) {
static thread_local const HashMap<EagerSliceBoxingType, std::string> boxing_type2log_prefix = {
{EagerSliceBoxingType::kNaiveBToS, "\t\tInternal boxing of naive-b-to-s, "},
{EagerSliceBoxingType::kNaivePToB, "\t\tInternal boxing of naive-p-to-b, "},
{EagerSliceBoxingType::kNaivePToS, "\t\tInternal boxing of naive-p-to-s, "},
{EagerSliceBoxingType::kNaiveSToB, "\t\tInternal boxing of naive-s-to-b, "},
{EagerSliceBoxingType::kNaiveSToP, "\t\tInternal boxing of naive-s-to-p, "},
{EagerSliceBoxingType::kNaiveSToS, "\t\tInternal boxing of naive-s-to-s, "}};
return CHECK_JUST(MapAt(boxing_type2log_prefix, boxing_type));
}
} // namespace private_details
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_
#define ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {
enum class EagerSliceBoxingType : unsigned int;
namespace private_details {
// Copy the input tensor to CPU if its device is neither CPU nor CUDA; otherwise return it unchanged
Maybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
const std::string& log_prefix);
// Copy the output tensor to the device of placed_nd_sbp if their device types differ;
// otherwise return it unchanged
Maybe<one::Tensor> PostprocessOutputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
Symbol<PlacedNdSbp> placed_nd_sbp,
const std::string& log_prefix);
const std::string& LogPrefix4EagerSliceBoxingType(EagerSliceBoxingType boxing_type);
} // namespace private_details
enum class EagerSliceBoxingType : unsigned int {
kNaiveBToS = 0,
kNaivePToB = 1,
kNaivePToS = 2,
kNaiveSToB = 3,
kNaiveSToP = 4,
kNaiveSToS = 5
};
template<EagerSliceBoxingType boxing_type>
struct EagerSliceBoxingAutoConvert {
template<Maybe<one::Tensor> (*func)(const std::shared_ptr<one::Tensor>&, Symbol<PlacedNdSbp>,
Symbol<PlacedNdSbp>)>
static Maybe<one::Tensor> Call(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
std::shared_ptr<one::Tensor> processed_in_tensor =
JUST(private_details::PreprocessInputTensor4SliceBoxing(
tensor, private_details::LogPrefix4EagerSliceBoxingType(boxing_type)));
const auto& new_in =
JUST(PlacedNdSbp::New(in->nd_sbp(), JUST(processed_in_tensor->parallel_desc())));
Symbol<ParallelDesc> new_out_placement = JUST(ReplaceDeviceType(
out->placement(), JUST(processed_in_tensor->parallel_desc())->device_type()));
const auto& new_out = JUST(PlacedNdSbp::New(out->nd_sbp(), new_out_placement));
std::shared_ptr<one::Tensor> out_tensor = JUST(func(processed_in_tensor, new_in, new_out));
return JUST(private_details::PostprocessOutputTensor4SliceBoxing(
out_tensor, out, private_details::LogPrefix4EagerSliceBoxingType(boxing_type)));
}
};
#define EAGER_SLICE_BOXING_WRAPPER(fn_ptr, boxing_type) \
(&EagerSliceBoxingAutoConvert<boxing_type>::Call<fn_ptr>)
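// Illustrative registration sketch (hypothetical, not taken from this commit):
// `NaiveSToS` and `CheckNaiveSToS` are assumed names for a boxing function and its
// checker with the signatures expected above. The wrapper stages tensors living on
// devices other than CPU/CUDA onto CPU around the wrapped call:
//
//   COMMAND(RegisterBoxingFunction(
//       "naive-s-to-s", CheckNaiveSToS,
//       EAGER_SLICE_BOXING_WRAPPER(&NaiveSToS, EagerSliceBoxingType::kNaiveSToS)));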
} // namespace oneflow
#endif // ONEFLOW_CORE_BOXING_SLICE_BOXING_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/boxing/eager_boxing_logger.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/placement_sbp_util.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace {
Maybe<one::OpExpr> MakeToConsistentOpExpr() {
std::shared_ptr<one::OpExpr> op_expr =
JUST(one::CastToConsistentOpExpr::New(*JUST(UniqueStr("cast_to_consistent"))));
return op_expr;
}
static constexpr auto* GetLocalToConsistentOpExpr =
DECORATE(&MakeToConsistentOpExpr, ThreadLocalCachedCopiable);
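// Re-wraps the current rank's physical tensor as a consistent tensor with the given
// shape/placement/nd_sbp via the cached cast_to_consistent op, reshaping the local
// piece to its physical shape first if necessary.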
Maybe<one::Tensor> ReinterpretConsistentTensor(const std::shared_ptr<one::Tensor>& tensor,
const Shape& shape,
Symbol<ParallelDesc> parallel_desc,
Symbol<NdSbp> nd_sbp) {
const auto& op = JUST(GetLocalToConsistentOpExpr());
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", tensor->dtype()->data_type()));
const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
std::shared_ptr<Shape> physical_shape =
JUST(GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id)));
std::shared_ptr<one::Tensor> x = JUST(tensor->cur_rank_phy_tensor());
if (*x->shape() != *physical_shape) { x = JUST(one::functional::Reshape(x, *physical_shape)); }
return JUST(one::OpInterpUtil::Dispatch<one::Tensor>(
*op, {x}, one::OpExprInterpContext(attrs, parallel_desc, nd_sbp)));
}
Maybe<one::Tensor> Apply1DBoxing(const std::shared_ptr<one::Tensor>& input, Symbol<NdSbp> in_nd_sbp,
Symbol<NdSbp> out_nd_sbp, Symbol<ParallelDesc> in_parallel_desc,
Symbol<ParallelDesc> out_parallel_desc) {
const auto& boxing_interpreter =
JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc, *input->shape()));
Singleton<const EagerBoxingLogger>::Get()->Log(
*JUST(boxing_interpreter->boxing_interpreter_status()),
/* prefix */ "\t\tInternal boxing of symmetric-acyclic-nd-sbp-to-nd-sbp, ");
return JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, in_parallel_desc,
out_parallel_desc));
}
// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckSymmetricAcyclicNdSbpBoxing(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->nd_sbp() != out->nd_sbp());
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), out->nd_sbp()->sbp_parallel_size());
CHECK_GT_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
JUST(CheckIsNdSbpBoxingAcyclicWithDecompose(in, out, logical_shape));
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
static constexpr auto* CheckSymmetricAcyclicNdSbpBoxing =
DECORATE(&RawCheckSymmetricAcyclicNdSbpBoxing, ThreadLocalCopiable);
} // namespace
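// Boxing between two acyclic nd-sbp's on the same placement: the transform is
// decomposed into a chain of naive transformations, and each step reinterprets the
// tensor under the sub-meta and applies ordinary 1-D eager boxing.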
Maybe<one::Tensor> SymmetricAcyclicNdSbpBoxing(const std::shared_ptr<one::Tensor>& input,
Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(input->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
<< Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
<< ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
const auto& tensor_placement = JUST(input->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement())
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
const auto& out_nd_sbp = out->nd_sbp();
const auto& out_parallel_desc = out->placement();
std::shared_ptr<one::Tensor> output;
const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));
if (out_parallel_id->has_value()) {
const auto& tensor_meta = JUST(input->consistent_tensor_meta());
const auto& naive_transformations =
JUST(DecomposeIntoNaiveTransformations(tensor_meta, out_nd_sbp));
std::shared_ptr<one::Tensor> tensor = input;
for (const auto& naive_transformation : *naive_transformations) {
const auto& sub_tensor_meta = naive_transformation.consistent_tensor_meta;
tensor = JUST(ReinterpretConsistentTensor(tensor, sub_tensor_meta->shape(),
sub_tensor_meta->parallel_desc(),
sub_tensor_meta->nd_sbp()));
tensor =
JUST(Apply1DBoxing(tensor, sub_tensor_meta->nd_sbp(), naive_transformation.dst_nd_sbp,
sub_tensor_meta->parallel_desc(), sub_tensor_meta->parallel_desc()));
}
output =
JUST(ReinterpretConsistentTensor(tensor, *input->shape(), out_parallel_desc, out_nd_sbp));
} else {
one::ConsistentTensorMeta tensor_meta(input->shape(), input->dtype()->data_type(), out_nd_sbp,
out_parallel_desc);
const auto& tensor_impl = JUST(
one::EagerConsistentTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));
output = std::make_shared<one::ConsistentTensor>(tensor_impl);
}
return output;
}
COMMAND(RegisterBoxingFunction("symmetric-acyclic-nd-sbp-to-nd-sbp",
CheckSymmetricAcyclicNdSbpBoxing, &SymmetricAcyclicNdSbpBoxing));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace {
Maybe<void> RawCheckSymmetricBToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));
CHECK_OR_RETURN(NdSbpIsAllPartialSum(*out->nd_sbp()));
CHECK_OR_RETURN(in->placement() == out->placement());
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
static constexpr auto* CheckSymmetricBToP =
DECORATE(&RawCheckSymmetricBToP, ThreadLocalCachedCopiable);
} // namespace
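// Broadcast-to-partial-sum without any communication: only the first rank of the
// placement keeps its local data and every other rank contributes zeros, so the
// per-rank partial sums still add up to the original broadcast tensor.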
Maybe<one::Tensor> SymmetricBToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
<< Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
<< ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement())
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
int64_t root = JUST(tensor_placement->MachineId4ParallelId(0));
std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
if (root != GlobalProcessCtx::Rank()) {
local_tensor = JUST(one::functional::ZerosLike(local_tensor));
}
return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
*JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
tensor->dtype()));
}
COMMAND(RegisterBoxingFunction("symmetric-b-to-p", CheckSymmetricBToP, &SymmetricBToP));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/register/tensor_slice_view.h"
#include "oneflow/core/job/nd_sbp_util.h"
namespace oneflow {
namespace {
bool RawIsBroadcastSbp(Symbol<SbpParallel> sbp_parallel) {
return sbp_parallel->has_broadcast_parallel();
}
static constexpr auto* IsBroadcastSbp = DECORATE(&RawIsBroadcastSbp, ThreadLocalCached);
bool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }
static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);
// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckSymmetricB2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(IsBroadcastSbp(SymbolOf(in->nd_sbp()->sbp_parallel(0))));
CHECK_OR_RETURN(IsSplitSbp(SymbolOf(out->nd_sbp()->sbp_parallel(0))));
CHECK_OR_RETURN(in->placement() == out->placement());
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
static constexpr auto* CheckSymmetricB2S =
DECORATE(&RawCheckSymmetricB2S, ThreadLocalCachedCopiable);
} // namespace
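// Broadcast-to-split without any communication: every rank already holds the full
// tensor, so each rank just slices out its own shard. E.g. (illustrative) a (4, 6)
// broadcast tensor re-split as split(0) over 2 ranks keeps rows [0, 2) on rank 0 and
// rows [2, 4) on rank 1.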
Maybe<one::Tensor> SymmetricB2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
<< Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
<< ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement())
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
const auto& local_shape = *tensor->shape();
std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(tensor_placement));
if (parallel_id->has_value()) {
const TensorSliceView& in_slice = GetTensorSliceView4ParallelId(
*tensor_placement->hierarchy(), *tensor_nd_sbp, local_shape, JUST(*parallel_id));
CHECK(!in_slice.IsEmpty());
const TensorSliceView& out_slice = GetTensorSliceView4ParallelId(
*tensor_placement->hierarchy(), *out->nd_sbp(), local_shape, JUST(*parallel_id));
CHECK(!out_slice.IsEmpty());
const TensorSliceView& intersection = out_slice.Intersect(in_slice);
CHECK(!intersection.IsEmpty());
const std::vector<Range>& range_vec = intersection.range_vec();
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step(range_vec.size(), 1);
for (const auto& range : range_vec) {
start.emplace_back(range.begin());
stop.emplace_back(range.end());
}
local_tensor = JUST(one::functional::Slice(local_tensor, start, stop, step,
/*enable_view_slice=*/false));
}
return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
*JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
tensor->dtype()));
}
COMMAND(RegisterBoxingFunction("symmetric-b-to-s", CheckSymmetricB2S, &SymmetricB2S));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
namespace oneflow {
namespace {
bool RawIsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }
static constexpr auto* IsSplitSbp = DECORATE(&RawIsSplitSbp, ThreadLocalCached);
bool RawIsPartialSumSbp(Symbol<SbpParallel> sbp_parallel) {
return sbp_parallel->has_partial_sum_parallel();
}
static constexpr auto* IsPartialSumSbp = DECORATE(&RawIsPartialSumSbp, ThreadLocalCached);
Maybe<one::UserOpExpr> EagerSymmetricSToP(Symbol<ParallelDesc> parallel_desc,
Symbol<SbpParallel> src_sbp, const Shape& logical_shape) {
return one::OpBuilder("eager_symmetric_s_to_p", *JUST(UniqueStr("eager_symmetric_s_to_p")))
.Input("in")
.Output("out")
.Attr<int64_t>("in_split_axis", src_sbp->split_parallel().axis())
.Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
.Build();
}
static constexpr auto* CachedEagerSymmetricSToPOpExpr =
DECORATE(&EagerSymmetricSToP, ThreadLocalCachedCopiable);
// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckSymmetricSToP(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(IsSplitSbp(SymbolOf(in->nd_sbp()->sbp_parallel(0))));
CHECK_OR_RETURN(IsPartialSumSbp(SymbolOf(out->nd_sbp()->sbp_parallel(0))));
CHECK_OR_RETURN(in->placement() == out->placement());
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
static constexpr auto* CheckSymmetricSToP =
DECORATE(&RawCheckSymmetricSToP, ThreadLocalCachedCopiable);
} // namespace
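// Split-to-partial-sum delegates to the `eager_symmetric_s_to_p` user op; semantically,
// each rank scatters its split shard into its slice of an otherwise-zero full-size
// tensor, so the resulting partial-sum components add back to the original tensor.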
Maybe<one::Tensor> SymmetricSToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
<< Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
<< ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement())
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
std::shared_ptr<one::OpExpr> op_expr = JUST(CachedEagerSymmetricSToPOpExpr(
tensor_placement, SymbolOf(tensor_nd_sbp->sbp_parallel(0)), *tensor->shape()));
return JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {tensor}));
}
COMMAND(RegisterBoxingFunction("symmetric-s-to-p", CheckSymmetricSToP, &SymmetricSToP));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
namespace {
// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckUnflattenHierarchy(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_GT_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
for (int i = 0; i < out->nd_sbp()->sbp_parallel_size(); ++i) {
const auto& sbp_parallel = out->nd_sbp()->sbp_parallel(i);
CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << "nd_sbp axis: " << i;
}
CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type());
CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num());
ParallelConf unflattened_parallel_conf(in->placement()->parallel_conf());
unflattened_parallel_conf.mutable_hierarchy()->CopyFrom(
out->placement()->parallel_conf().hierarchy());
const auto& unflatten_placement = SymbolOf(ParallelDesc(unflattened_parallel_conf));
CHECK_OR_RETURN(unflatten_placement == out->placement())
<< "The output placement is not a hierarch-unflattened version of the input placement";
for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();
++in_parallel_id) {
const auto& in_physical_shape =
JUST(GetPhysicalShape(logical_shape, *in->nd_sbp(), *in->placement(), in_parallel_id));
const auto& out_physical_shape =
JUST(GetPhysicalShape(logical_shape, *out->nd_sbp(), *out->placement(), in_parallel_id));
CHECK_EQ_OR_RETURN(*in_physical_shape, *out_physical_shape);
}
return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
} // namespace
static constexpr auto* CheckUnflattenHierarchy =
DECORATE(&RawCheckUnflattenHierarchy, ThreadLocalCachedCopiable);
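// Unflattening only reinterprets the hierarchy of the same set of ranks, e.g.
// (illustrative) a placement over 4 ranks with hierarchy (4) and sbp broadcast viewed
// as hierarchy (2, 2) with nd_sbp (broadcast, broadcast); the check above verifies the
// per-rank physical shapes are identical, so no data moves.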
Maybe<one::Tensor> UnflattenHierarchy(const std::shared_ptr<one::Tensor>& tensor,
Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
<< Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
<< ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement())
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());
const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
*tensor->shape(), tensor->dtype()));
}
COMMAND(RegisterBoxingFunction("unflatten-hierarchy", CheckUnflattenHierarchy,
&UnflattenHierarchy));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/common/constant.h"
namespace oneflow {
namespace ccl {
namespace {
Maybe<void> InitBroadcastRankHeap(std::vector<int64_t>* ranks, const ParallelDesc& parallel_desc,
int64_t root) {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), parallel_desc.sorted_machine_ids().size());
ranks->resize(parallel_desc.parallel_num());
int64_t root_index = -1;
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_id));
if (machine_id == root) { root_index = parallel_id; }
(*ranks)[parallel_id] = machine_id;
}
CHECK_NE_OR_RETURN(root_index, -1);
std::swap((*ranks)[0], (*ranks)[root_index]);
return Maybe<void>::Ok();
}
int64_t RingDecrease(int64_t n, int64_t size) { return (n - 1 + size) % size; }
int64_t RingIncrease(int64_t n, int64_t size) { return (n + 1 + size) % size; }
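// E.g. with size == 4: RingDecrease(0, 4) == 3 and RingIncrease(3, 4) == 0.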
template<typename T>
void VecAdd(size_t size, T* out, const T* in0, const T* in1) {
size_t thread_num = Singleton<ThreadPool>::Get()->thread_num();
BalancedSplitter bs(size, thread_num);
MultiThreadLoop(thread_num, [&](size_t thread_idx) {
size_t end = bs.At(thread_idx).end();
for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) { out[i] = in0[i] + in1[i]; }
});
}
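// BalancedSplitter(n, m), used throughout this file, divides n elements into m
// near-equal contiguous ranges, with earlier ranges taking the remainder; e.g.
// BalancedSplitter(10, 3) yields [0, 4), [4, 7), [7, 10), so bs.At(0).size() is the
// largest chunk size.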
} // namespace
template<typename T, ReduceType reduce_type>
struct DtypeAllReduce;
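// Ring all-reduce in two phases of parallel_num - 1 steps each: a reduce-scatter pass
// in which every rank accumulates one chunk of the final sum, followed by an
// all-gather pass that circulates the reduced chunks to all ranks.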
template<typename T>
struct DtypeAllReduce<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,
Symbol<ParallelDesc> parallel_desc) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }
return Maybe<void>::Ok();
}
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
BalancedSplitter bs(elem_cnt, parallel_num);
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
Optional<int64_t> parallel_id;
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = JUST(parallel_id); i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = &out[bs.At(send_part_id).begin()];
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
T* cur_out = &out[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, cur_out, cur_in, recv_ptr); }
}
for (int64_t i = 0, part_id = RingIncrease(JUST(parallel_id), parallel_num);
i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = &out[bs.At(send_part_id).begin()];
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = &out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
}
return Maybe<void>::Ok();
}
};
#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeAllReduce, MAKE_ALL_REDUCE_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_ALL_REDUCE_ENTRY
template<>
Maybe<void> AllReduce<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
return SwitchDtypeAllReduce(SwitchCase(dtype, reduce_type), in, out, elem_cnt, parallel_desc);
}
template<typename T, ReduceType reduce_type>
struct DtypeReduceScatter;
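// Ring reduce-scatter: the logical input of elem_cnt * parallel_num elements is split
// into parallel_num chunks; after parallel_num - 1 ring steps each rank holds the
// fully reduced chunk assigned to it in `out`.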
template<typename T>
struct DtypeReduceScatter<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,
Symbol<ParallelDesc> parallel_desc) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }
return Maybe<void>::Ok();
}
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
BalancedSplitter bs(elem_cnt * parallel_num, parallel_num);
const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
CHECK_OR_RETURN(opt_parallel_id->has_value());
int64_t parallel_id = JUST(*opt_parallel_id);
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = RingDecrease(parallel_id, parallel_num); i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = out;
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, out, cur_in, recv_ptr); }
}
return Maybe<void>::Ok();
}
};
#define MAKE_REDUCE_SCATTER_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeReduceScatter, MAKE_REDUCE_SCATTER_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_REDUCE_SCATTER_ENTRY
template<>
Maybe<void> ReduceScatter<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt,
DataType dtype, ReduceType reduce_type,
Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
return SwitchDtypeReduceScatter(SwitchCase(dtype, reduce_type), in, out, elem_cnt, parallel_desc);
}
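// Ring all-gather: each rank starts with its own chunk already in place and, after
// parallel_num - 1 ring steps of forwarding the most recently received chunk, holds
// all chunks.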
template<>
Maybe<void> AllGather<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (in != out) { std::memcpy(out, in, elem_cnt * GetSizeOfDataType(dtype)); }
return Maybe<void>::Ok();
}
char* char_out = reinterpret_cast<char*>(out);
size_t chunk_size = elem_cnt * GetSizeOfDataType(dtype);
BalancedSplitter bs(chunk_size * parallel_num, parallel_num);
const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
CHECK_OR_RETURN(opt_parallel_id->has_value());
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
int64_t parallel_id = JUST(*opt_parallel_id);
// Skip the copy when the operation is in-place, i.e. in == &char_out[parallel_id * chunk_size]
if (in != &char_out[parallel_id * chunk_size]) {
memcpy(&char_out[parallel_id * chunk_size], in, chunk_size);
}
for (int64_t i = 0, part_id = parallel_id; i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const void* send_ptr = &char_out[bs.At(send_part_id).begin()];
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
void* recv_ptr = &char_out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<void*>(send_ptr);
*size = send_size;
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size;
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> Broadcast<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
int64_t root, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
CHECK_EQ_OR_RETURN(parallel_desc->device_type(), DeviceType::kCPU);
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
return CpuBroadcast(in, out, buffer_size, root, parallel_desc, transport_token);
}
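// Tree broadcast over a binary heap of ranks rooted at `root` (see
// InitBroadcastRankHeap above): every rank first receives the buffer from its heap
// parent, then forwards it to its heap children, completing in O(log n) rounds.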
Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
Symbol<ParallelDesc> parallel_desc,
const TransportToken& transport_token) {
static thread_local std::vector<int64_t> rank_heap{};
JUST(InitBroadcastRankHeap(&rank_heap, *parallel_desc, root));
auto Send = [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = (root == GlobalProcessCtx::Rank() ? const_cast<void*>(in) : out);
*size = buffer_size;
*Cb = [] {};
return Maybe<void>::Ok();
};
auto Recv = [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = out;
*size = buffer_size;
*Cb = [] {};
return Maybe<void>::Ok();
};
{
NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv);
JUST(TransportUtil::ReceiveDataFromParentInHeap(rank_heap, transport_token, &transport_ctx));
JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg);
}
{
NaiveAsyncTransportCtx transport_ctx(transport_token, Send, Recv);
JUST(TransportUtil::SendDataToChildrenInHeap(rank_heap, transport_token, &transport_ctx));
if (GlobalProcessCtx::Rank() == root && out != in) { std::memcpy(out, in, buffer_size); }
JUST_MSG(transport_ctx.WaitDone(), kAsymmetricCodeErrorMsg);
}
return Maybe<void>::Ok();
}
template<typename T, ReduceType reduce_type>
struct DtypeReduce;
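// Reduce to `root` in two phases: a ring reduce-scatter like the all-reduce above,
// after which each rank owns one reduced chunk, followed by a point-to-point phase in
// which every non-root rank sends its reduced chunk to root.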
template<typename T>
struct DtypeReduce<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt, int64_t root,
Symbol<ParallelDesc> parallel_desc) {
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
int64_t parallel_num = parallel_desc->parallel_num();
BalancedSplitter bs(elem_cnt, parallel_num);
size_t size = root == GlobalProcessCtx::Rank() && void_in != void_out ? 0 : bs.At(0).size();
T* tmp_out = nullptr;
// void_out is only used on rank root and ignored for other ranks.
auto tmp_out_buffer = std::make_unique<T[]>(size);
int64_t parallel_id_of_root =
JUST(parallel_desc->ParallelId4MachineDeviceId(root, GlobalProcessCtx::LocalRank(root)));
if (root == GlobalProcessCtx::Rank() && void_in != void_out) {
tmp_out = &reinterpret_cast<T*>(void_out)[bs.At(parallel_id_of_root).begin()];
} else {
tmp_out = tmp_out_buffer.get();
}
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
Optional<int64_t> parallel_id;
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = RingDecrease(JUST(parallel_id), parallel_num);
i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = tmp_out;
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, tmp_out, cur_in, recv_ptr); }
}
if (root == GlobalProcessCtx::Rank() && void_in == void_out) {
memcpy(&out[bs.At(parallel_id_of_root).begin()], tmp_out,
bs.At(parallel_id_of_root).size() * sizeof(T));
}
for (int64_t i = 0, part_id = RingIncrease(parallel_id_of_root, parallel_num);
i < parallel_num - 1; ++i, part_id = RingIncrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
int64_t src_rank = JUST(parallel_desc->MachineId4ParallelId(send_part_id));
const T* send_ptr = tmp_out;
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = part_id;
T* recv_ptr = &out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
if (send_size > 0 && src_rank == GlobalProcessCtx::Rank()) {
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
});
JUST(TransportUtil::SendDataToRank(root, transport_token, &ctx));
JUST(ctx.WaitDone());
}
if (recv_size > 0 && root == GlobalProcessCtx::Rank()) {
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
JUST(TransportUtil::ReceiveDataFromRank(src_rank, transport_token, &ctx));
JUST(ctx.WaitDone());
}
}
return Maybe<void>::Ok();
}
};
#define MAKE_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeReduce, MAKE_REDUCE_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_REDUCE_ENTRY
template<>
Maybe<void> Reduce<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, int64_t root,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream) {
return SwitchDtypeReduce(SwitchCase(dtype, reduce_type), in, out, elem_cnt, root, parallel_desc);
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
std::set<std::pair<int64_t, int64_t>> device_set;
const int64_t& rank = GlobalProcessCtx::Rank();
const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
device_set.emplace(rank, GlobalProcessCtx::LocalRank());
device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
peer_nccl_rank};
}
auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
#endif
template<>
Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<void*>(in);
*size = buffer_size;
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
});
JUST(TransportUtil::SendDataToRank(dst, transport_token, &transport_ctx));
JUST(transport_ctx.WaitDone());
return Maybe<void>::Ok();
}
#ifdef WITH_CUDA
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU send is only supported when nccl version >= 2.7";
#endif
}
#endif
#ifdef WITH_ROCM
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
}
#endif
template<>
Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
NaiveAsyncTransportCtx transport_ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = out;
*size = buffer_size;
*Cb = [] {};
return Maybe<void>::Ok();
});
JUST(TransportUtil::ReceiveDataFromRank(src, transport_token, &transport_ctx));
JUST(transport_ctx.WaitDone());
return Maybe<void>::Ok();
}
#ifdef WITH_CUDA
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU recv is only supported when nccl version >= 2.7";
#endif
}
#endif
#ifdef WITH_ROCM
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
}
#endif
} // namespace ccl
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_CCL_CCL_H_
#define ONEFLOW_CORE_CCL_CCL_H_
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/device_type.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/ep/include/stream.h"
namespace oneflow {
class DeviceCtx;
class ParallelDesc;
class TransportToken;
// collective communication library
namespace ccl {
#define CCL_REDUCE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(kSum)
enum ReduceType {
kInvalidReduceFunctorType = 0,
#define DEFINE_REDUCE_TYPE_ENUM_VALUE(enum_value) enum_value,
OF_PP_FOR_EACH_TUPLE(DEFINE_REDUCE_TYPE_ENUM_VALUE, CCL_REDUCE_TYPE_SEQ)
#undef DEFINE_REDUCE_TYPE_ENUM_VALUE
kReduceTypeSize
};
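// With the default CCL_REDUCE_TYPE_SEQ above, the enum expands to:
// kInvalidReduceFunctorType = 0, kSum, kReduceTypeSize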
#define CCL_REDUCE_TYPE_CTRV_SEQ \
MAKE_TYPED_CTRV_SEQ(ReduceType, \
OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, CCL_REDUCE_TYPE_SEQ))
template<DeviceType device_type>
Maybe<void> AllReduce(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> ReduceScatter(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> AllGather(const void* in, void* out, size_t elem_cnt, DataType dtype,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Broadcast(const void* in, void* out, size_t elem_cnt, DataType dtype, int64_t root,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Reduce(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, int64_t root, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
Symbol<ParallelDesc> parallel_desc, const TransportToken& transport_token);
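// Illustrative CPU usage sketch (hypothetical call site; must run inside a
// Maybe<void> function, and `parallel_desc` is assumed to describe the participating
// ranks, each holding n floats):
//
//   std::vector<float> in(n, 1.0f), out(n);
//   JUST(ccl::AllReduce<DeviceType::kCPU>(in.data(), out.data(), n, DataType::kFloat,
//                                         ccl::kSum, parallel_desc,
//                                         /*stream=*/nullptr));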
} // namespace ccl
} // namespace oneflow
#endif // ONEFLOW_CORE_CCL_CCL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
CommNet::~CommNet() {
ready_cbs_.Close();
ready_cb_poller_.join();
}
void* CommNet::NewActorReadId() { return new ActorReadContext; }
void CommNet::DeleteActorReadId(void* actor_read_id) {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
CHECK(actor_read_ctx->waiting_list.empty());
delete actor_read_ctx;
}
void CommNet::Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token) {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
ReadContext* read_ctx = new ReadContext;
read_ctx->actor_read_ctx = actor_read_ctx;
auto do_read = [this, read_ctx, src_machine_id, src_token, dst_token]() {
DoRead(read_ctx, src_machine_id, src_token, dst_token);
};
AddWorkToStream(actor_read_id, do_read, true);
}
void CommNet::AddReadCallBack(void* actor_read_id, std::function<void()> callback) {
AddWorkToStream(actor_read_id, callback, false);
}
void CommNet::ReadDone(void* read_id) {
ReadContext* read_ctx = static_cast<ReadContext*>(read_id);
ActorReadContext* actor_read_ctx = read_ctx->actor_read_ctx;
CommNetItem item;
std::unique_lock<std::mutex> lck(actor_read_ctx->waiting_list_mtx);
CHECK(!actor_read_ctx->waiting_list.empty());
CHECK(actor_read_ctx->waiting_list.front().callback == nullptr);
actor_read_ctx->waiting_list.pop_front();
while (true) {
if (actor_read_ctx->waiting_list.empty()) { break; }
item = actor_read_ctx->waiting_list.front();
actor_read_ctx->waiting_list.pop_front();
CHECK(item.callback);
ready_cbs_.Send(item.callback);
if (item.is_read) { break; }
}
delete read_ctx;
}
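// Work items on an actor's waiting list run in FIFO order. A read enqueues a sentinel
// (an empty CommNetItem) marking the in-flight read; callbacks queued behind it stay
// parked until ReadDone pops the sentinel and releases them up to, and including, the
// next read.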
void CommNet::AddWorkToStream(void* actor_read_id, const std::function<void()>& cb, bool is_read) {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
std::unique_lock<std::mutex> lck(actor_read_ctx->waiting_list_mtx);
if (actor_read_ctx->waiting_list.empty()) {
ready_cbs_.Send(cb);
} else {
CommNetItem work_item(is_read, cb);
actor_read_ctx->waiting_list.emplace_back(work_item);
}
if (is_read) {
CommNetItem empty_cb;
actor_read_ctx->waiting_list.emplace_back(empty_cb);
}
}
CommNet::CommNet() {
int64_t this_machine_id = GlobalProcessCtx::Rank();
for (int64_t i : Singleton<ResourceDesc, ForSession>::Get()->process_ranks()) {
if (i == this_machine_id) { continue; }
peer_machine_id_.insert(i);
}
ready_cb_poller_ = std::thread([this]() {
std::function<void()> cb;
while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); }
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_
#define ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_
#define DEPRECATED __attribute__((deprecated))
#include "oneflow/core/lazy/actor/actor_message.h"
#include "oneflow/core/common/platform.h"
#include "oneflow/core/common/channel.h"
namespace oneflow {
struct CommNetItem {
bool is_read;
std::function<void()> callback;
CommNetItem() : CommNetItem(false, nullptr) {}
CommNetItem(bool read, const std::function<void()>& cb) : is_read(read), callback(cb) {}
};
class CommNet {
public:
OF_DISALLOW_COPY_AND_MOVE(CommNet);
virtual ~CommNet();
// "RegisterMemory" will return a Token, after "RegisterMemoryDone",
// we can use this token to use the "Read"
virtual void* RegisterMemory(void* ptr, size_t byte_size) = 0;
virtual void UnRegisterMemory(void* token) = 0;
// Stream
void* NewActorReadId();
void DeleteActorReadId(void* actor_read_id);
void Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token);
void AddReadCallBack(void* actor_read_id, std::function<void()> callback);
void ReadDone(void* read_id);
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0;
protected:
CommNet();
virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) = 0;
const HashSet<int64_t>& peer_machine_id() { return peer_machine_id_; }
Channel<std::function<void()>> ready_cbs_;
private:
friend class Singleton<CommNet>;
void AddWorkToStream(void* actor_read_id, const std::function<void()>& cb, bool is_read);
struct ActorReadContext;
struct ReadContext {
ActorReadContext* actor_read_ctx;
};
struct ActorReadContext {
std::mutex waiting_list_mtx;
std::list<CommNetItem> waiting_list;
};
HashSet<int64_t> peer_machine_id_;
std::thread ready_cb_poller_;
};
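// A minimal usage sketch of the read stream (illustrative only; `net`,
// `src_token`, and `dst_token` are hypothetical):
//
//   void* read_id = net->NewActorReadId();
//   net->Read(read_id, src_machine_id, src_token, dst_token);  // async remote read
//   net->AddReadCallBack(read_id, []() { /* runs after the read completes */ });
//   ...
//   net->DeleteActorReadId(read_id);  // only once no work is pending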
template<typename MemDescType>
class CommNetIf : public CommNet {
public:
OF_DISALLOW_COPY_AND_MOVE(CommNetIf);
CommNetIf() : CommNet() {}
virtual ~CommNetIf() {}
void* RegisterMemory(void* ptr, size_t byte_size) override {
std::unique_lock<std::mutex> lck(mem_descs_mtx_);
MemDescType* mem_desc = NewMemDesc(ptr, byte_size);
CHECK(mem_descs_.insert(mem_desc).second);
return mem_desc;
}
  void UnRegisterMemory(void* token) override {
    std::unique_lock<std::mutex> lck(mem_descs_mtx_);
    MemDescType* mem_desc = static_cast<MemDescType*>(token);
    CHECK_EQ(mem_descs_.erase(mem_desc), 1);  // erase before delete to avoid a dangling key
    delete mem_desc;
  }
protected:
virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) = 0;
const HashSet<MemDescType*>& mem_descs() { return mem_descs_; }
private:
std::mutex mem_descs_mtx_;
HashSet<MemDescType*> mem_descs_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMM_NETWORK_COMM_NETWORK_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef __linux__
#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "glog/logging.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/job/global_for.h"
#include <netinet/tcp.h>
namespace oneflow {
namespace {
static const int32_t kInvalidPort = 0;
sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {
  sockaddr_in sa{};  // value-initialize so sin_zero is cleared
sa.sin_family = AF_INET;
sa.sin_port = htons(port);
PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1)
<< "addr: " << addr << ", port: " << port;
return sa;
}
int SockListen(int listen_sockfd, int32_t* listen_port, int32_t total_machine_num) {
  // If *listen_port == kInvalidPort, the system assigns an available port;
  // otherwise the configured port is used.
sockaddr_in sa = GetSockAddr("0.0.0.0", *listen_port);
int reuse = 1;
int ret_setopt =
setsockopt(listen_sockfd, SOL_SOCKET, SO_REUSEADDR, (const void*)&reuse, sizeof(int));
CHECK_EQ(ret_setopt, 0);
  int bind_result = bind(listen_sockfd, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));
  if (bind_result == 0) {
    sockaddr_in bound_sock;
    socklen_t bound_sock_size = sizeof(bound_sock);
    PCHECK(getsockname(listen_sockfd, reinterpret_cast<sockaddr*>(&bound_sock), &bound_sock_size)
           == 0);
    if (*listen_port != kInvalidPort) {
      CHECK_EQ(*listen_port, static_cast<int32_t>(ntohs(bound_sock.sin_port)));
    } else {
      *listen_port = static_cast<int32_t>(ntohs(bound_sock.sin_port));
    }
    PCHECK(listen(listen_sockfd, total_machine_num) == 0);
    LOG(INFO) << "CommNet:Epoll listening on "
              << "0.0.0.0:" + std::to_string(*listen_port);
  } else {
    PCHECK(errno == EACCES || errno == EADDRINUSE) << "SockListen errno: " << errno;
  }
return bind_result;
}
std::string GenPortKey(int64_t machine_id) { return "EpollPort/" + std::to_string(machine_id); }
void PushPort(int64_t machine_id, uint16_t port) {
Singleton<CtrlClient>::Get()->PushKV(GenPortKey(machine_id), std::to_string(port));
}
void ClearPort(int64_t machine_id) {
Singleton<CtrlClient>::Get()->ClearKV(GenPortKey(machine_id));
}
uint16_t PullPort(int64_t machine_id) {
uint16_t port = 0;
Singleton<CtrlClient>::Get()->PullKV(
GenPortKey(machine_id), [&](const std::string& v) { port = oneflow_cast<uint16_t>(v); });
return port;
}
} // namespace
EpollCommNet::~EpollCommNet() {
for (size_t i = 0; i < pollers_.size(); ++i) {
VLOG(1) << "CommNet Thread " << i << " finish";
pollers_[i]->Stop();
}
OF_ENV_BARRIER();
for (IOEventPoller* poller : pollers_) { delete poller; }
for (auto& pair : sockfd2helper_) { delete pair.second; }
}
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
if (actor_msg.IsDataRegstMsgToConsumer()) {
msg.actor_msg.set_comm_net_token(actor_msg.regst()->comm_net_token());
}
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
}
void EpollCommNet::SendTransportMsg(int64_t dst_machine_id, const TransportMsg& transport_msg) {
SocketMsg msg;
msg.msg_type = SocketMsgType::kTransport;
msg.transport_msg = transport_msg;
SendSocketMsg(dst_machine_id, msg);
}
void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg) {
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
}
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
SocketMemDesc* mem_desc = new SocketMemDesc;
mem_desc->mem_ptr = ptr;
mem_desc->byte_size = byte_size;
return mem_desc;
}
EpollCommNet::EpollCommNet() : CommNetIf() {
pollers_.resize(Singleton<ResourceDesc, ForSession>::Get()->CommNetWorkerNum(), nullptr);
for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; }
InitSockets();
for (IOEventPoller* poller : pollers_) { poller->Start(); }
}
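// Connection setup uses the CtrlClient key-value store as a rendezvous: each
// rank publishes its listen port under "EpollPort/<rank>", actively connects
// to every peer with a larger rank, and accepts one connection from each
// smaller rank.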
void EpollCommNet::InitSockets() {
int64_t this_machine_id = GlobalProcessCtx::Rank();
auto this_machine = Singleton<ResourceDesc, ForSession>::Get()->machine(this_machine_id);
int64_t total_machine_num = Singleton<ResourceDesc, ForSession>::Get()->process_ranks().size();
machine_id2sockfd_.assign(total_machine_num, -1);
sockfd2helper_.clear();
size_t poller_idx = 0;
auto NewSocketHelper = [&](int sockfd) {
IOEventPoller* poller = pollers_[poller_idx];
poller_idx = (poller_idx + 1) % pollers_.size();
return new SocketHelper(sockfd, poller);
};
// listen
  int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);
  PCHECK(listen_sockfd != -1);
  int32_t this_listen_port = kInvalidPort;
{
if (this_machine.data_port_agent() != -1) {
this_listen_port = this_machine.data_port_agent();
} else if (Singleton<EnvDesc>::Get()->data_port() != -1) {
this_listen_port = Singleton<EnvDesc>::Get()->data_port();
}
}
CHECK_EQ(SockListen(listen_sockfd, &this_listen_port, total_machine_num), 0);
CHECK_NE(this_listen_port, 0);
PushPort(this_machine_id, this_listen_port);
int32_t src_machine_count = 0;
// connect
for (int64_t peer_id : peer_machine_id()) {
if (peer_id < this_machine_id) {
++src_machine_count;
continue;
}
uint16_t peer_port = PullPort(peer_id);
auto peer_machine = Singleton<ResourceDesc, ForSession>::Get()->machine(peer_id);
sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port);
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    PCHECK(sockfd != -1);
const int val = 1;
PCHECK(setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(int)) == 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
ssize_t n = write(sockfd, &this_machine_id, sizeof(int64_t));
PCHECK(n == sizeof(int64_t));
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfd_[peer_id] = sockfd;
}
// accept
HashSet<int64_t> processed_ranks;
FOR_RANGE(int32_t, idx, 0, src_machine_count) {
sockaddr_in peer_sockaddr;
socklen_t len = sizeof(peer_sockaddr);
int sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
PCHECK(sockfd != -1);
int64_t peer_rank;
ssize_t n = read(sockfd, &peer_rank, sizeof(int64_t));
PCHECK(n == sizeof(int64_t));
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
CHECK(processed_ranks.emplace(peer_rank).second);
machine_id2sockfd_[peer_rank] = sockfd;
}
PCHECK(close(listen_sockfd) == 0);
ClearPort(this_machine_id);
  // Log the sockfd assigned to each machine.
FOR_RANGE(int64_t, machine_id, 0, total_machine_num) {
VLOG(2) << "machine " << machine_id << " sockfd " << machine_id2sockfd_[machine_id];
}
}
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id) {
int sockfd = machine_id2sockfd_.at(machine_id);
return sockfd2helper_.at(sockfd);
}
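// The epoll transport is pull-based: DoRead sends a kRequestWrite message to
// the source machine, which replies with a kRequestRead header followed by the
// payload; SocketReadHelper on this side copies the payload into dst_token's
// buffer and finally calls ReadDone.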
void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) {
SocketMsg msg;
msg.msg_type = SocketMsgType::kRequestWrite;
msg.request_write_msg.src_token = src_token;
msg.request_write_msg.dst_machine_id = GlobalProcessCtx::Rank();
msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = read_id;
GetSocketHelper(src_machine_id)->AsyncWrite(msg);
}
} // namespace oneflow
#endif // __linux__
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_
#ifdef __linux__
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/comm_network/epoll/socket_helper.h"
#include "oneflow/core/comm_network/epoll/socket_memory_desc.h"
namespace oneflow {
class EpollCommNet final : public CommNetIf<SocketMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(EpollCommNet);
~EpollCommNet();
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
void SendTransportMsg(int64_t dst_machine_id, const TransportMsg& msg);
private:
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override;
friend class Singleton<EpollCommNet>;
EpollCommNet();
void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id);
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
std::vector<IOEventPoller*> pollers_;
std::vector<int> machine_id2sockfd_;
HashMap<int, SocketHelper*> sockfd2helper_;
};
} // namespace oneflow
#endif // __linux__
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef __linux__
#include "oneflow/core/comm_network/epoll/io_event_poller.h"
#include <sys/eventfd.h>
namespace oneflow {
const int IOEventPoller::max_event_num_ = 32;
IOEventPoller::IOEventPoller() {
  epfd_ = epoll_create1(0);
  PCHECK(epfd_ != -1);
ep_events_ = new epoll_event[max_event_num_];
io_handlers_.clear();
break_epoll_loop_fd_ = eventfd(0, 0);
PCHECK(break_epoll_loop_fd_ != -1);
AddFdWithOnlyReadHandler(break_epoll_loop_fd_, []() { VLOG(1) << "Break Epoll Loop"; });
}
IOEventPoller::~IOEventPoller() {
for (IOHandler* handler : io_handlers_) {
PCHECK(close(handler->fd) == 0);
delete handler;
}
delete[] ep_events_;
PCHECK(close(epfd_) == 0);
}
void IOEventPoller::AddFd(int fd, std::function<void()> read_handler,
std::function<void()> write_handler) {
AddFd(fd, &read_handler, &write_handler);
}
void IOEventPoller::AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler) {
AddFd(fd, &read_handler, nullptr);
}
void IOEventPoller::Start() { thread_ = std::thread(&IOEventPoller::EpollLoop, this); }
void IOEventPoller::Stop() {
uint64_t break_epoll_loop_event = 1;
  PCHECK(write(break_epoll_loop_fd_, &break_epoll_loop_event, sizeof(break_epoll_loop_event))
         == sizeof(break_epoll_loop_event));
thread_.join();
}
void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
std::function<void()>* write_handler) {
// Set Fd NONBLOCK
int opt = fcntl(fd, F_GETFL);
PCHECK(opt != -1);
PCHECK(fcntl(fd, F_SETFL, opt | O_NONBLOCK) == 0);
// Set CLOEXEC
opt = fcntl(fd, F_GETFD);
PCHECK(opt != -1);
PCHECK(fcntl(fd, F_SETFD, opt | FD_CLOEXEC) == 0);
// New IOHandler on Heap
IOHandler* io_handler = new IOHandler;
if (read_handler) { io_handler->read_handler = *read_handler; }
if (write_handler) { io_handler->write_handler = *write_handler; }
io_handler->fd = fd;
io_handlers_.push_front(io_handler);
// Add Fd to Epoll
epoll_event ep_event;
  ep_event.events = EPOLLET;
  // EPOLLRDHUP must be requested explicitly for EpollLoop's peer-shutdown check to fire.
  if (read_handler) { ep_event.events |= EPOLLIN | EPOLLRDHUP; }
  if (write_handler) { ep_event.events |= EPOLLOUT; }
ep_event.data.ptr = io_handler;
PCHECK(epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &ep_event) == 0);
}
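// Edge-triggered (EPOLLET) notification fires only on state changes, so every
// handler must drain its fd until EAGAIN/EWOULDBLOCK before returning; see
// SocketReadHelper::ReadUntilSocketNotReadable for the read side.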
void IOEventPoller::EpollLoop() {
while (true) {
int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1);
if (event_num == -1) {
PCHECK(errno == EINTR);
continue;
}
const epoll_event* cur_event = ep_events_;
for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) {
auto io_handler = static_cast<IOHandler*>(cur_event->data.ptr);
PCHECK(!(cur_event->events & EPOLLERR)) << "fd: " << io_handler->fd;
if (io_handler->fd == break_epoll_loop_fd_) { return; }
if (cur_event->events & EPOLLIN) {
if (cur_event->events & EPOLLRDHUP) {
LOG(FATAL) << "fd " << io_handler->fd << " closed by peer";
} else {
io_handler->read_handler();
}
}
if (cur_event->events & EPOLLOUT) { io_handler->write_handler(); }
}
}
}
} // namespace oneflow
#endif // __linux__
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_
#include "oneflow/core/comm_network/epoll/socket_message.h"
#ifdef OF_PLATFORM_POSIX
namespace oneflow {
class IOEventPoller final {
public:
OF_DISALLOW_COPY_AND_MOVE(IOEventPoller);
IOEventPoller();
~IOEventPoller();
void AddFd(int fd, std::function<void()> read_handler, std::function<void()> write_handler);
void AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler);
void Start();
void Stop();
private:
struct IOHandler {
IOHandler() {
read_handler = []() { UNIMPLEMENTED(); };
write_handler = []() { UNIMPLEMENTED(); };
fd = -1;
}
std::function<void()> read_handler;
std::function<void()> write_handler;
int fd;
};
void AddFd(int fd, std::function<void()>* read_handler, std::function<void()>* write_handler);
void EpollLoop();
static const int max_event_num_;
int epfd_;
epoll_event* ep_events_;
std::forward_list<IOHandler*> io_handlers_;
int break_epoll_loop_fd_;
std::thread thread_;
};
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_IO_EVENT_POLLER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef __linux__
#include "oneflow/core/comm_network/epoll/socket_helper.h"
namespace oneflow {
SocketHelper::SocketHelper(int sockfd, IOEventPoller* poller) {
read_helper_ = new SocketReadHelper(sockfd);
write_helper_ = new SocketWriteHelper(sockfd, poller);
poller->AddFd(
sockfd, [this]() { read_helper_->NotifyMeSocketReadable(); },
[this]() { write_helper_->NotifyMeSocketWriteable(); });
}
SocketHelper::~SocketHelper() {
delete read_helper_;
delete write_helper_;
}
void SocketHelper::AsyncWrite(const SocketMsg& msg) { write_helper_->AsyncWrite(msg); }
} // namespace oneflow
#endif // __linux__
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_
#include "oneflow/core/comm_network/epoll/io_event_poller.h"
#include "oneflow/core/comm_network/epoll/socket_read_helper.h"
#include "oneflow/core/comm_network/epoll/socket_write_helper.h"
#ifdef OF_PLATFORM_POSIX
namespace oneflow {
class SocketHelper final {
public:
OF_DISALLOW_COPY_AND_MOVE(SocketHelper);
SocketHelper() = delete;
~SocketHelper();
SocketHelper(int sockfd, IOEventPoller* poller);
void AsyncWrite(const SocketMsg& msg);
private:
SocketReadHelper* read_helper_;
SocketWriteHelper* write_helper_;
};
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_HELPER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_
#include "oneflow/core/comm_network/epoll/socket_memory_desc.h"
#ifdef OF_PLATFORM_POSIX
namespace oneflow {
struct SocketMemDesc {
void* mem_ptr;
size_t byte_size;
};
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MEMORY_DESC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_
#include "oneflow/core/common/platform.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/comm_network/comm_network.h"
#ifdef OF_PLATFORM_POSIX
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "oneflow/core/lazy/actor/actor_message.h"
#include "oneflow/core/transport/transport_message.h"
namespace oneflow {
#define SOCKET_MSG_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(RequestWrite, request_write) \
OF_PP_MAKE_TUPLE_SEQ(RequestRead, request_read) \
OF_PP_MAKE_TUPLE_SEQ(Actor, actor) \
OF_PP_MAKE_TUPLE_SEQ(Transport, transport)
enum class SocketMsgType {
#define MAKE_ENTRY(x, y) k##x,
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ)
#undef MAKE_ENTRY
};
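// With SOCKET_MSG_TYPE_SEQ above, this expands to:
//   enum class SocketMsgType { kRequestWrite, kRequestRead, kActor, kTransport };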
struct RequestWriteMsg {
void* src_token;
int64_t dst_machine_id;
void* dst_token;
void* read_id;
};
struct RequestReadMsg {
void* src_token;
void* dst_token;
void* read_id;
};
struct SocketMsg {
SocketMsgType msg_type;
union {
#define MAKE_ENTRY(x, y) x##Msg y##_msg;
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ)
#undef MAKE_ENTRY
};
};
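// The union holds one member per message kind (request_write_msg,
// request_read_msg, actor_msg, transport_msg), discriminated by msg_type.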
using CallBackList = std::list<std::function<void()>>;
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_MESSAGE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef __linux__
#include "oneflow/core/comm_network/epoll/socket_read_helper.h"
#include "oneflow/core/lazy/actor/actor_message_bus.h"
#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "oneflow/core/transport/transport.h"
#include <netinet/tcp.h>
namespace oneflow {
SocketReadHelper::~SocketReadHelper() {
// do nothing
}
SocketReadHelper::SocketReadHelper(int sockfd) {
sockfd_ = sockfd;
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::NotifyMeSocketReadable() { ReadUntilSocketNotReadable(); }
void SocketReadHelper::SwitchToMsgHeadReadHandle() {
cur_read_handle_ = &SocketReadHelper::MsgHeadReadHandle;
read_ptr_ = reinterpret_cast<char*>(&cur_msg_);
read_size_ = sizeof(cur_msg_);
}
void SocketReadHelper::ReadUntilSocketNotReadable() {
while ((this->*cur_read_handle_)()) {}
}
bool SocketReadHelper::MsgHeadReadHandle() {
return DoCurRead(&SocketReadHelper::SetStatusWhenMsgHeadDone);
}
bool SocketReadHelper::MsgBodyReadHandle() {
return DoCurRead(&SocketReadHelper::SetStatusWhenMsgBodyDone);
}
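// One step of the incremental read state machine: a short read advances
// read_ptr_/read_size_ and keeps draining; EAGAIN/EWOULDBLOCK means the socket
// is exhausted for now and the edge-triggered poller will notify us again.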
bool SocketReadHelper::DoCurRead(void (SocketReadHelper::*set_cur_read_done)()) {
ssize_t n = read(sockfd_, read_ptr_, read_size_);
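  // TCP_QUICKACK is not sticky, so it is re-armed after every read; this
  // presumably avoids delayed-ACK latency in the request/response protocol.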
const int val = 1;
PCHECK(setsockopt(sockfd_, IPPROTO_TCP, TCP_QUICKACK, (char*)&val, sizeof(int)) == 0);
if (n == read_size_) {
(this->*set_cur_read_done)();
return true;
} else if (n >= 0) {
read_ptr_ += n;
read_size_ -= n;
return true;
} else {
CHECK_EQ(n, -1);
PCHECK(errno == EAGAIN || errno == EWOULDBLOCK);
return false;
}
}
void SocketReadHelper::SetStatusWhenMsgHeadDone() {
switch (cur_msg_.msg_type) {
#define MAKE_ENTRY(x, y) \
case SocketMsgType::k##x: SetStatusWhen##x##MsgHeadDone(); break;
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
default: UNIMPLEMENTED();
}
}
void SocketReadHelper::SetStatusWhenMsgBodyDone() {
if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {
Singleton<EpollCommNet>::Get()->ReadDone(cur_msg_.request_read_msg.read_id);
}
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::SetStatusWhenRequestWriteMsgHeadDone() {
SocketMsg msg_to_send;
msg_to_send.msg_type = SocketMsgType::kRequestRead;
msg_to_send.request_read_msg.src_token = cur_msg_.request_write_msg.src_token;
msg_to_send.request_read_msg.dst_token = cur_msg_.request_write_msg.dst_token;
msg_to_send.request_read_msg.read_id = cur_msg_.request_write_msg.read_id;
Singleton<EpollCommNet>::Get()->SendSocketMsg(cur_msg_.request_write_msg.dst_machine_id,
msg_to_send);
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::SetStatusWhenRequestReadMsgHeadDone() {
auto mem_desc = static_cast<const SocketMemDesc*>(cur_msg_.request_read_msg.dst_token);
read_ptr_ = reinterpret_cast<char*>(mem_desc->mem_ptr);
read_size_ = mem_desc->byte_size;
cur_read_handle_ = &SocketReadHelper::MsgBodyReadHandle;
}
void SocketReadHelper::SetStatusWhenActorMsgHeadDone() {
Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(cur_msg_.actor_msg);
SwitchToMsgHeadReadHandle();
}
void SocketReadHelper::SetStatusWhenTransportMsgHeadDone() {
Singleton<Transport>::Get()->EnqueueTransportMsg(cur_msg_.transport_msg);
SwitchToMsgHeadReadHandle();
}
} // namespace oneflow
#endif // __linux__