Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
@@ -49,8 +49,9 @@ Maybe<one::Tensor> GetIdentity(const std::shared_ptr<one::Tensor>& tensor, Symbo
  // reset sbp if parallel_num == 1 and reset transport_token
  const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("identity", DECORATE(&RawCheckIdentity, ThreadLocalCachedCopiable),
...
@@ -67,9 +67,9 @@ Maybe<one::Tensor> Naive1ToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<
    local_tensor = JUST(one::functional::Constant(*tensor->shape(), 0, tensor->dtype(),
                                                  JUST(Device::New(device_type))));
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
+      tensor->dtype(), /* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("naive-1-to-p", CheckNaive1ToP, &Naive1ToP));
...
@@ -52,9 +52,9 @@ Maybe<one::Tensor> NaiveBTo1(const std::shared_ptr<one::Tensor>& tensor, Symbol<
      << *JUST(PlacementToString(in->placement())) << ")";
  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
+      tensor->dtype(), /* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("naive-b-to-1", CheckNaiveBTo1, &NaiveBTo1));
...
@@ -74,8 +74,9 @@ Maybe<one::Tensor> NaiveBToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<
    }
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/false));
}
static constexpr auto* NaiveBToSWithAutoConvert =
...
@@ -74,8 +74,9 @@ Maybe<one::Tensor> NaivePToB(const std::shared_ptr<one::Tensor>& tensor, Symbol<
  }
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/false));
}
static constexpr auto* NaivePToBWithAutoConvert =
...
@@ -73,8 +73,9 @@ Maybe<one::Tensor> NaivePToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<
    }
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ true, /*copy=*/false));
}
static constexpr auto* NaivePToSWithAutoConvert =
...
@@ -73,8 +73,9 @@ Maybe<one::Tensor> NaiveSToB(const std::shared_ptr<one::Tensor>& tensor, Symbol<
  }
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/false));
}
static constexpr auto* NaiveSToBWithAutoConvert =
...
@@ -73,8 +73,9 @@ Maybe<one::Tensor> NaiveSToP(const std::shared_ptr<one::Tensor>& tensor, Symbol<
  }
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/false));
}
static constexpr auto* NaiveSToPWithAutoConvert =
...
@@ -71,8 +71,9 @@ Maybe<one::Tensor> NaiveSToS(const std::shared_ptr<one::Tensor>& tensor, Symbol<
    }
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *out_sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *out_sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/false));
}
static constexpr auto* NaiveSToSWithAutoConvert =
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace {
bool IsSplitSbp(Symbol<SbpParallel> sbp_parallel) { return sbp_parallel->has_split_parallel(); }
Maybe<void> RawCheckNcclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCUDA);
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
static constexpr auto* CheckNcclP2B = DECORATE(&RawCheckNcclP2B, ThreadLocalCachedCopiable);
Maybe<void> RawCheckNcclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
CHECK_OR_RETURN(NdSbpIsAllSplit(*out->nd_sbp(), 0));
CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCUDA);
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
static constexpr auto* CheckNcclP2S = DECORATE(&RawCheckNcclP2S, ThreadLocalCachedCopiable);
Maybe<void> RawCheckNcclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(NdSbpIsAllSplit(*in->nd_sbp(), 0));
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCUDA);
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
static constexpr auto* CheckNcclS2B = DECORATE(&RawCheckNcclS2B, ThreadLocalCachedCopiable);
Maybe<void> RawCheckNcclS2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
CHECK_OR_RETURN(IsSplitSbp(in->nd_sbp()->sbp_parallel(0)));
CHECK_OR_RETURN(IsSplitSbp(out->nd_sbp()->sbp_parallel(0)));
CHECK_NE_OR_RETURN(in->nd_sbp()->sbp_parallel(0).split_parallel().axis(),
out->nd_sbp()->sbp_parallel(0).split_parallel().axis());
int64_t in_split_axis = in->nd_sbp()->sbp_parallel(0).split_parallel().axis();
int64_t out_split_axis = out->nd_sbp()->sbp_parallel(0).split_parallel().axis();
CHECK_GT_OR_RETURN(logical_shape.NumAxes(), in_split_axis);
CHECK_GT_OR_RETURN(logical_shape.NumAxes(), out_split_axis);
CHECK_OR_RETURN(logical_shape.At(in_split_axis) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(logical_shape.At(out_split_axis) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCUDA);
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
static constexpr auto* CheckNcclS2S = DECORATE(&RawCheckNcclS2S, ThreadLocalCachedCopiable);
} // namespace
Maybe<one::Tensor> NcclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()); // NOLINT(maybe-need-error-msg)
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement()); // NOLINT(maybe-need-error-msg)
return JUST(one::functional::ConsistentAllReduce(tensor));
}
Maybe<one::Tensor> NcclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()); // NOLINT(maybe-need-error-msg)
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement()); // NOLINT(maybe-need-error-msg)
return JUST(one::functional::ConsistentReduceScatter(tensor, "sum"));
}
Maybe<one::Tensor> NcclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()); // NOLINT(maybe-need-error-msg)
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement()); // NOLINT(maybe-need-error-msg)
return JUST(one::functional::ConsistentAllGather(tensor));
}
Maybe<one::Tensor> NcclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
Symbol<PlacedNdSbp> out) {
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()); // NOLINT(maybe-need-error-msg)
const auto& tensor_placement = JUST(tensor->parallel_desc());
CHECK_OR_RETURN(tensor_placement == in->placement()); // NOLINT(maybe-need-error-msg)
return JUST(one::functional::ConsistentS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
}
COMMAND(RegisterBoxingFunction("nccl-p-to-b", CheckNcclP2B, &NcclP2B));
COMMAND(RegisterBoxingFunction("nccl-p-to-s", CheckNcclP2S, &NcclP2S));
COMMAND(RegisterBoxingFunction("nccl-s-to-b", CheckNcclS2B, &NcclS2B));
COMMAND(RegisterBoxingFunction("nccl-s-to-s", CheckNcclS2S, &NcclS2S));
} // namespace oneflow
@@ -27,7 +27,7 @@ namespace oneflow {
namespace {
Maybe<std::tuple<Symbol<PlacedNdSbp>, Symbol<PlacedNdSbp>>> RawInOutPlacedNdSbpDimReduce(
-    Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
+    Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out, const Shape& logical_shape) {
  // reduce hierarchy
  ParallelDesc reduced_in_placement = *in->placement();
  ParallelDesc reduced_out_placement = *out->placement();
@@ -35,14 +35,14 @@ Maybe<std::tuple<Symbol<PlacedNdSbp>, Symbol<PlacedNdSbp>>> RawInOutPlacedNdSbpD
  NdSbp reduced_out_nd_sbp;
  InOutParallelDimReduce(*in->placement(), *out->placement(), *in->nd_sbp(), *out->nd_sbp(),
                         &reduced_in_placement, &reduced_out_placement, &reduced_in_nd_sbp,
-                         &reduced_out_nd_sbp);
+                         &reduced_out_nd_sbp, logical_shape);
  return std::make_tuple(
      JUST(PlacedNdSbp::New(SymbolOf(reduced_in_nd_sbp), SymbolOf(reduced_in_placement))),
      JUST(PlacedNdSbp::New(SymbolOf(reduced_out_nd_sbp), SymbolOf(reduced_out_placement))));
}
constexpr auto* InOutPlacedNdSbpDimReduce =
-    DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCached);
+    DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCachedCopiable);
// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckParallelDimReduce(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
@@ -51,7 +51,7 @@ Maybe<void> RawCheckParallelDimReduce(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp
  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());
  Symbol<PlacedNdSbp> reduced_in;
  Symbol<PlacedNdSbp> reduced_out;
-  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out));
+  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, logical_shape));
  for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();
       ++in_parallel_id) {
@@ -102,13 +102,13 @@ Maybe<one::Tensor> ParallelDimReduce(const std::shared_ptr<one::Tensor>& tensor,
  Symbol<PlacedNdSbp> reduced_in;
  Symbol<PlacedNdSbp> reduced_out;
-  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out));
+  std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, *tensor->shape()));
  const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());
-  std::shared_ptr<one::Tensor> reduced_in_tensor = JUST(one::functional::LocalToConsistent(
-      local_tensor, reduced_in->placement(), *JUST(GetSbpList(reduced_in->nd_sbp())),
-      *tensor->shape(), tensor->dtype()));
+  std::shared_ptr<one::Tensor> reduced_in_tensor = JUST(one::functional::LocalToGlobal(
+      local_tensor, reduced_in->placement(), *JUST(GetSbpList(reduced_in->nd_sbp())),
+      *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false));
  const auto& boxing_interpreter =
      JUST(Singleton<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
@@ -124,9 +124,9 @@ Maybe<one::Tensor> ParallelDimReduce(const std::shared_ptr<one::Tensor>& tensor,
  const std::shared_ptr<one::Tensor>& reduced_out_local_tensor =
      JUST(reduced_out_tensor->cur_rank_phy_tensor());
-  return JUST(one::functional::LocalToConsistent(reduced_out_local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      reduced_out_local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())),
+      *tensor->shape(), tensor->dtype(), /* sync_data */ false, /*copy=*/false));
}
COMMAND(RegisterBoxingFunction("nd-sbp-dim-reduce", CheckParallelDimReduce, &ParallelDimReduce));
...
@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
+#include "oneflow/user/kernels/communicate_util.h"
namespace oneflow {
@@ -31,6 +32,7 @@ Maybe<void> RawCheckNaiveOneToOne(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> ou
  CHECK_EQ_OR_RETURN(out->placement()->parallel_num(), 1);
  CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());
  CHECK_OR_RETURN(in->placement() != out->placement());
+  CHECK_OR_RETURN(IsSendAndRecvRegistered(in->placement()->device_type()));  // NOLINT
  return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
@@ -56,7 +58,9 @@ Maybe<one::Tensor> NaiveOneToOne(const std::shared_ptr<one::Tensor>& tensor, Sym
  int64_t src = JUST(tensor_placement->MachineId4ParallelId(0));
  int64_t dst = JUST(out->placement()->MachineId4ParallelId(0));
+  bool copy = true;
  if (src != dst) {
+    copy = false;
    if (GlobalProcessCtx::Rank() == src) {
      JUST(one::functional::Send(local_tensor, dst, /* send_meta */ false));
    }
@@ -65,9 +69,9 @@ Maybe<one::Tensor> NaiveOneToOne(const std::shared_ptr<one::Tensor>& tensor, Sym
                                           JUST(local_tensor->device()), NullOpt));
    }
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
+      tensor->dtype(), /* sync_data */ false, /*copy=*/copy));
}
COMMAND(RegisterBoxingFunction("naive-1-to-1", CheckNaiveOneToOne, &NaiveOneToOne));
...
@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/boxing/eager_boxing_logger.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
+#include "oneflow/user/kernels/communicate_util.h"
namespace oneflow {
@@ -26,10 +27,7 @@ namespace private_details {
Maybe<one::Tensor> PreprocessInputTensor4SliceBoxing(const std::shared_ptr<one::Tensor>& tensor,
                                                     const std::string& log_prefix) {
  const auto& tensor_placement = JUST(tensor->parallel_desc());
-  if (tensor_placement->device_type() == DeviceType::kCPU
-      || tensor_placement->device_type() == DeviceType::kCUDA) {
-    return tensor;
-  }
+  if (IsSendAndRecvRegistered(tensor_placement->device_type())) { return tensor; }
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  Symbol<ParallelDesc> new_placement = JUST(ReplaceDeviceType(tensor_placement, DeviceType::kCPU));
...
@@ -29,30 +29,17 @@ namespace oneflow {
namespace {
-Maybe<one::OpExpr> MakeToConsistentOpExpr() {
-  std::shared_ptr<one::OpExpr> op_expr =
-      JUST(one::CastToConsistentOpExpr::New(*JUST(UniqueStr("cast_to_consistent"))));
-  return op_expr;
-}
-static constexpr auto* GetLocalToConsistentOpExpr =
-    DECORATE(&MakeToConsistentOpExpr, ThreadLocalCachedCopiable);
-Maybe<one::Tensor> ReinterpterConsistentTensor(const std::shared_ptr<one::Tensor>& tensor,
-                                               const Shape& shape,
-                                               Symbol<ParallelDesc> parallel_desc,
-                                               Symbol<NdSbp> nd_sbp) {
-  const auto& op = JUST(GetLocalToConsistentOpExpr());
-  MutableAttrMap attrs;
-  JUST(attrs.SetAttr<Shape>("shape", shape));
-  JUST(attrs.SetAttr<DataType>("dtype", tensor->dtype()->data_type()));
+Maybe<one::Tensor> ReinterpterGlobalTensor(const std::shared_ptr<one::Tensor>& tensor,
+                                           const Shape& shape, Symbol<ParallelDesc> parallel_desc,
+                                           Symbol<NdSbp> nd_sbp) {
  const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
  std::shared_ptr<Shape> pyhsical_shape =
      JUST(GetPhysicalShape(shape, *nd_sbp, *parallel_desc, JUST(*parallel_id)));
  std::shared_ptr<one::Tensor> x = JUST(tensor->cur_rank_phy_tensor());
  if (*x->shape() != *pyhsical_shape) { x = JUST(one::functional::Reshape(x, *pyhsical_shape)); }
-  return JUST(one::OpInterpUtil::Dispatch<one::Tensor>(
-      *op, {x}, one::OpExprInterpContext(attrs, parallel_desc, nd_sbp)));
+  return JUST(one::functional::LocalToGlobal(x, parallel_desc, *JUST(GetSbpList(nd_sbp)), shape,
+                                             tensor->dtype(), /* sync_data */ false,
+                                             /*copy=*/false));
}
Maybe<one::Tensor> Apply1DBoxing(const std::shared_ptr<one::Tensor>& input, Symbol<NdSbp> in_nd_sbp,
@@ -81,7 +68,7 @@ Maybe<void> RawCheckSymmetricAcyclicNdSbpBoxing(Symbol<PlacedNdSbp> in, Symbol<P
// NOLINTEND(maybe-need-error-msg)
static constexpr auto* CheckSymmetricAcyclicNdSbpBoxing =
-    DECORATE(&RawCheckSymmetricAcyclicNdSbpBoxing, ThreadLocalCopiable);
+    DECORATE(&RawCheckSymmetricAcyclicNdSbpBoxing, ThreadLocalCachedCopiable);
}  // namespace
@@ -101,27 +88,26 @@ Maybe<one::Tensor> SymmetricAcyclicNdSbpBoxing(const std::shared_ptr<one::Tensor
  std::shared_ptr<one::Tensor> output;
  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_parallel_desc));
  if (out_parallel_id->has_value()) {
-    const auto& tensor_meta = JUST(input->consistent_tensor_meta());
+    const auto& tensor_meta = JUST(input->global_tensor_meta());
    const auto& naive_transformations =
        JUST(DecomposeIntoNaiveTransformations(tensor_meta, out_nd_sbp));
    std::shared_ptr<one::Tensor> tensor = input;
    for (const auto& naive_transformation : *naive_transformations) {
-      const auto& sub_tensor_meta = naive_transformation.consistent_tensor_meta;
-      tensor = JUST(ReinterpterConsistentTensor(tensor, sub_tensor_meta->shape(),
-                                                sub_tensor_meta->parallel_desc(),
-                                                sub_tensor_meta->nd_sbp()));
+      const auto& sub_tensor_meta = naive_transformation.global_tensor_meta;
+      tensor = JUST(ReinterpterGlobalTensor(tensor, sub_tensor_meta->shape(),
+                                            sub_tensor_meta->parallel_desc(),
+                                            sub_tensor_meta->nd_sbp()));
      tensor =
          JUST(Apply1DBoxing(tensor, sub_tensor_meta->nd_sbp(), naive_transformation.dst_nd_sbp,
                             sub_tensor_meta->parallel_desc(), sub_tensor_meta->parallel_desc()));
    }
-    output =
-        JUST(ReinterpterConsistentTensor(tensor, *input->shape(), out_parallel_desc, out_nd_sbp));
+    output = JUST(ReinterpterGlobalTensor(tensor, *input->shape(), out_parallel_desc, out_nd_sbp));
  } else {
-    one::ConsistentTensorMeta tensor_meta(input->shape(), input->dtype()->data_type(), out_nd_sbp,
-                                          out_parallel_desc);
-    const auto& tensor_impl = JUST(
-        one::EagerConsistentTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));
-    output = std::make_shared<one::ConsistentTensor>(tensor_impl);
+    one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(), out_nd_sbp,
+                                      out_parallel_desc);
+    const auto& tensor_impl =
+        JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));
+    output = std::make_shared<one::GlobalTensor>(tensor_impl);
  }
  return output;
}
...
@@ -63,9 +63,9 @@ Maybe<one::Tensor> SymmetricBToP(const std::shared_ptr<one::Tensor>& tensor, Sym
  } else {
    local_tensor = JUST(one::functional::ZerosLike(local_tensor));
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
+      tensor->dtype(), /* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("symmetric-b-to-p", CheckSymmetricBToP, &SymmetricBToP));
...
@@ -45,6 +45,8 @@ Maybe<void> RawCheckSymmetricB2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out
  CHECK_OR_RETURN(IsSplitSbp(SymbolOf(out->nd_sbp()->sbp_parallel(0))));
  CHECK_OR_RETURN(in->placement() == out->placement());
+  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
+                  || in->placement()->device_type() == DeviceType::kCUDA);
  return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
@@ -92,9 +94,9 @@ Maybe<one::Tensor> SymmetricB2S(const std::shared_ptr<one::Tensor>& tensor, Symb
                                            /*enable_view_slice=*/false));
  }
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
-                                                 *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
-                                                 tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(
+      local_tensor, out->placement(), *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
+      tensor->dtype(), /* sync_data */ false, /*copy=*/false));
}
COMMAND(RegisterBoxingFunction("symmetric-b-to-s", CheckSymmetricB2S, &SymmetricB2S));
...
@@ -70,8 +70,9 @@ Maybe<one::Tensor> UnflattenHierarchy(const std::shared_ptr<one::Tensor>& tensor
      << *JUST(PlacementToString(in->placement())) << ")";
  const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
-  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
-                                                 *tensor->shape(), tensor->dtype()));
+  return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
+                                             *tensor->shape(), tensor->dtype(),
+                                             /* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("unflatten-hierarchy", CheckUnflattenHierarchy,
...
@@ -24,11 +24,7 @@ limitations under the License.
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
-#ifdef WITH_ROCM
-#include "oneflow/core/ep/rocm/cuda_stream.h"
-#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
-#endif
#include "oneflow/core/common/constant.h"
namespace oneflow {
@@ -51,275 +47,8 @@ Maybe<void> InitBroadcastRankHeap(std::vector<int64_t>* ranks, const ParallelDes
  return Maybe<void>::Ok();
}
int64_t RingDecrease(int64_t n, int64_t size) { return (n - 1 + size) % size; }
int64_t RingIncrease(int64_t n, int64_t size) { return (n + 1 + size) % size; }
template<typename T>
void VecAdd(size_t size, T* out, const T* in0, const T* in1) {
size_t thread_num = Singleton<ThreadPool>::Get()->thread_num();
BalancedSplitter bs(size, thread_num);
MultiThreadLoop(thread_num, [&](size_t thread_idx) {
size_t end = bs.At(thread_idx).end();
for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) { out[i] = in0[i] + in1[i]; }
});
}
} // namespace
template<typename T, ReduceType reduce_type>
struct DtypeAllReduce;
template<typename T>
struct DtypeAllReduce<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,
Symbol<ParallelDesc> parallel_desc) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }
return Maybe<void>::Ok();
}
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
BalancedSplitter bs(elem_cnt, parallel_num);
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
Optional<int64_t> parallel_id;
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = JUST(parallel_id); i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = &out[bs.At(send_part_id).begin()];
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
T* cur_out = &out[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, cur_out, cur_in, recv_ptr); }
}
for (int64_t i = 0, part_id = RingIncrease(JUST(parallel_id), parallel_num);
i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = &out[bs.At(send_part_id).begin()];
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = &out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
}
return Maybe<void>::Ok();
}
};
#define MAKE_ALL_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeAllReduce, MAKE_ALL_REDUCE_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_ALL_REDUCE_ENTRY
template<>
Maybe<void> AllReduce<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
return SwitchDtypeAllReduce(SwitchCase(dtype, reduce_type), in, out, elem_cnt, parallel_desc);
}
template<typename T, ReduceType reduce_type>
struct DtypeReduceScatter;
template<typename T>
struct DtypeReduceScatter<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt,
Symbol<ParallelDesc> parallel_desc) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (void_in != void_out) { std::memcpy(void_out, void_in, elem_cnt * sizeof(T)); }
return Maybe<void>::Ok();
}
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
BalancedSplitter bs(elem_cnt * parallel_num, parallel_num);
const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
CHECK_OR_RETURN(opt_parallel_id->has_value());
int64_t parallel_id = JUST(*opt_parallel_id);
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = RingDecrease(parallel_id, parallel_num); i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = out;
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, out, cur_in, recv_ptr); }
}
return Maybe<void>::Ok();
}
};
#define MAKE_REDUCE_SCATTER_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeReduceScatter, MAKE_REDUCE_SCATTER_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_REDUCE_SCATTER_ENTRY
template<>
Maybe<void> ReduceScatter<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt,
DataType dtype, ReduceType reduce_type,
Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
return SwitchDtypeReduceScatter(SwitchCase(dtype, reduce_type), in, out, elem_cnt, parallel_desc);
}
template<>
Maybe<void> AllGather<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream) {
int64_t parallel_num = parallel_desc->parallel_num();
if (parallel_num == 1) {
if (in != out) { std::memcpy(out, in, elem_cnt * GetSizeOfDataType(dtype)); }
return Maybe<void>::Ok();
}
char* char_out = reinterpret_cast<char*>(out);
size_t chunk_size = elem_cnt * GetSizeOfDataType(dtype);
BalancedSplitter bs(chunk_size * parallel_num, parallel_num);
const auto& opt_parallel_id = JUST(GetParallelId4CurrentProcessCtx(parallel_desc));
CHECK_OR_RETURN(opt_parallel_id->has_value());
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
int64_t parallel_id = JUST(*opt_parallel_id);
// In-place operation will happen if in == out + parallel_id * chunk_size
if (in != &char_out[parallel_id * chunk_size]) {
memcpy(&char_out[parallel_id * chunk_size], in, chunk_size);
}
for (int64_t i = 0, part_id = parallel_id; i < parallel_num - 1;
++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const void* send_ptr = &char_out[bs.At(send_part_id).begin()];
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
void* recv_ptr = &char_out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<void*>(send_ptr);
*size = send_size;
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size;
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> Broadcast<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
int64_t root, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream) {
CHECK_EQ_OR_RETURN(parallel_desc->device_type(), DeviceType::kCPU);
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
const auto& transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
return CpuBroadcast(in, out, buffer_size, root, parallel_desc, transport_token);
}
Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
                         Symbol<ParallelDesc> parallel_desc,
                         const TransportToken& transport_token) {
@@ -351,171 +80,7 @@ Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t
  return Maybe<void>::Ok();
}
-template<typename T, ReduceType reduce_type>
+Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst) {
struct DtypeReduce;
template<typename T>
struct DtypeReduce<T, kSum> {
static Maybe<void> Call(const void* void_in, void* void_out, size_t elem_cnt, int64_t root,
Symbol<ParallelDesc> parallel_desc) {
const T* in = reinterpret_cast<const T*>(void_in);
T* out = reinterpret_cast<T*>(void_out);
int64_t parallel_num = parallel_desc->parallel_num();
BalancedSplitter bs(elem_cnt, parallel_num);
size_t size = root == GlobalProcessCtx::Rank() && void_in != void_out ? 0 : bs.At(0).size();
T* tmp_out = nullptr;
// void_out is only used on rank root and ignored for other ranks.
auto tmp_out_buffer = std::make_unique<T[]>(size);
int64_t parallel_id_of_root =
JUST(parallel_desc->ParallelId4MachineDeviceId(root, GlobalProcessCtx::LocalRank(root)));
if (root == GlobalProcessCtx::Rank() && void_in != void_out) {
tmp_out = &reinterpret_cast<T*>(void_out)[bs.At(parallel_id_of_root).begin()];
} else {
tmp_out = tmp_out_buffer.get();
}
auto recv_buffer = std::make_unique<T[]>(bs.At(0).size());
Optional<int64_t> parallel_id;
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
const auto& rank_group = JUST(RankGroup::New(parallel_desc));
TransportToken transport_token =
JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
for (int64_t i = 0, part_id = RingDecrease(JUST(parallel_id), parallel_num);
i < parallel_num - 1; ++i, part_id = RingDecrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
const T* send_ptr = nullptr;
if (i == 0) {
send_ptr = &in[bs.At(send_part_id).begin()];
} else {
send_ptr = tmp_out;
}
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = RingDecrease(part_id, parallel_num);
T* recv_ptr = recv_buffer.get();
size_t recv_size = bs.At(recv_part_id).size();
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
if (send_size > 0) {
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx));
}
if (recv_size > 0) {
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx));
}
JUST(ctx.WaitDone());
const T* cur_in = &in[bs.At(recv_part_id).begin()];
if (recv_size > 0) { VecAdd(recv_size, tmp_out, cur_in, recv_ptr); }
}
if (root == GlobalProcessCtx::Rank() && void_in == void_out) {
memcpy(&out[bs.At(parallel_id_of_root).begin()], tmp_out,
bs.At(parallel_id_of_root).size() * sizeof(T));
}
for (int64_t i = 0, part_id = RingIncrease(parallel_id_of_root, parallel_num);
i < parallel_num - 1; ++i, part_id = RingIncrease(part_id, parallel_num)) {
int64_t send_part_id = part_id;
int64_t src_rank = JUST(parallel_desc->MachineId4ParallelId(send_part_id));
const T* send_ptr = tmp_out;
size_t send_size = bs.At(send_part_id).size();
int64_t recv_part_id = part_id;
T* recv_ptr = &out[bs.At(recv_part_id).begin()];
size_t recv_size = bs.At(recv_part_id).size();
if (send_size > 0 && src_rank == GlobalProcessCtx::Rank()) {
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = const_cast<T*>(send_ptr);
*size = send_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
});
JUST(TransportUtil::SendDataToRank(root, transport_token, &ctx));
JUST(ctx.WaitDone());
}
if (recv_size > 0 && root == GlobalProcessCtx::Rank()) {
NaiveAsyncTransportCtx ctx(
transport_token,
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
UNIMPLEMENTED_THEN_RETURN();
},
[&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
*buffer = recv_ptr;
*size = recv_size * sizeof(T);
*Cb = [] {};
return Maybe<void>::Ok();
});
JUST(TransportUtil::ReceiveDataFromRank(src_rank, transport_token, &ctx));
JUST(ctx.WaitDone());
}
}
return Maybe<void>::Ok();
}
};
#define MAKE_REDUCE_ENTRY(func_name, T, reduce_type) func_name<T, reduce_type>::Call
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, DtypeReduce, MAKE_REDUCE_ENTRY,
MAKE_DATA_TYPE_CTRV_SEQ(POD_DATA_TYPE_SEQ), CCL_REDUCE_TYPE_CTRV_SEQ);
#undef MAKE_REDUCE_ENTRY
template<>
Maybe<void> Reduce<DeviceType::kCPU>(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, int64_t root,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream) {
return SwitchDtypeReduce(SwitchCase(dtype, reduce_type), in, out, elem_cnt, root, parallel_desc);
}
#ifdef WITH_CUDA
std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
std::set<std::pair<int64_t, int64_t>> device_set;
const int64_t& rank = GlobalProcessCtx::Rank();
const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
device_set.emplace(rank, GlobalProcessCtx::LocalRank());
device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
peer_nccl_rank};
}
auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
#endif
#ifdef WITH_ROCM
std::pair<ncclComm_t, int64_t> RawGetNcclCommAndPeerNcclRank(int64_t peer_process_id) {
std::set<std::pair<int64_t, int64_t>> device_set;
const int64_t& rank = GlobalProcessCtx::Rank();
const int64_t peer_nccl_rank = (peer_process_id > rank) ? 1 : 0;
device_set.emplace(rank, GlobalProcessCtx::LocalRank());
device_set.emplace(peer_process_id, GlobalProcessCtx::LocalRank(peer_process_id));
return {CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get())->GetCommForDevice(device_set),
peer_nccl_rank};
}
auto* GetNcclCommAndPeerNcclRank = DECORATE(&RawGetNcclCommAndPeerNcclRank, ThreadLocal);
#endif
template<>
Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
  NaiveAsyncTransportCtx transport_ctx(
      transport_token,
@@ -533,41 +98,7 @@ Maybe<void> Send<DeviceType::kCPU>(const void* in, size_t elem_cnt, DataType dty
  return Maybe<void>::Ok();
}
-#ifdef WITH_CUDA
+Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src) {
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU send is only supported when nccl version >= 2.7"
#endif
}
#endif
#ifdef WITH_ROCM
template<>
Maybe<void> Send<DeviceType::kCUDA>(const void* in, size_t elem_cnt, DataType dtype, int64_t dst,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
OF_NCCL_CHECK_OR_RETURN(ncclSend(in, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
}
#endif
template<>
Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
size_t buffer_size = elem_cnt * GetSizeOfDataType(dtype);
  TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData));
  NaiveAsyncTransportCtx transport_ctx(
      transport_token,
@@ -585,35 +116,5 @@ Maybe<void> Recv<DeviceType::kCPU>(void* out, size_t elem_cnt, DataType dtype, i
  return Maybe<void>::Ok();
}
#ifdef WITH_CUDA
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
#if NCCL_VERSION_CODE >= 2700
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
#else
UNIMPLEMENTED_THEN_RETURN() << "GPU recv is only supported when nccl version >= 2.7"
#endif
}
#endif
#ifdef WITH_ROCM
template<>
Maybe<void> Recv<DeviceType::kCUDA>(void* out, size_t elem_cnt, DataType dtype, int64_t src,
ep::Stream* stream) {
CHECK_OR_RETURN(IsPODDataType(dtype));
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(src);
OF_NCCL_CHECK_OR_RETURN(ncclRecv(out, elem_cnt, GetNcclDataType(dtype), comm_and_peer_rank.second,
comm_and_peer_rank.first,
stream->As<ep::CudaStream>()->cuda_stream()));
return Maybe<void>::Ok();
}
#endif
} // namespace ccl
} // namespace oneflow
@@ -24,55 +24,15 @@ limitations under the License.
namespace oneflow {
-class DeviceCtx;
class ParallelDesc;
class TransportToken;
// collective communication library
namespace ccl {
-#define CCL_REDUCE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(kSum)
-enum ReduceType {
+Maybe<void> CpuSend(const void* in, size_t buffer_size, int64_t dst);
+Maybe<void> CpuRecv(void* out, size_t buffer_size, int64_t src);
kInvalidReduceFunctorType = 0,
#define DEFINE_REDUCE_TYPE_ENUM_VALUE(enum_value) enum_value,
OF_PP_FOR_EACH_TUPLE(DEFINE_REDUCE_TYPE_ENUM_VALUE, CCL_REDUCE_TYPE_SEQ)
#undef DEFINE_REDUCE_TYPE_ENUM_VALUE
kReduceTypeSize
};
#define CCL_REDUCE_TYPE_CTRV_SEQ \
MAKE_TYPED_CTRV_SEQ(ReduceType, \
OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, CCL_REDUCE_TYPE_SEQ))
template<DeviceType device_type>
Maybe<void> AllReduce(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> ReduceScatter(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> AllGather(const void* in, void* out, size_t elem_cnt, DataType dtype,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Send(const void* in, size_t elem_cnt, DataType dtype, int64_t dst, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Recv(void* out, size_t elem_cnt, DataType dtype, int64_t src, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Broadcast(const void* in, void* out, size_t elem_cnt, DataType dtype, int64_t root,
Symbol<ParallelDesc> parallel_desc, ep::Stream* stream);
template<DeviceType device_type>
Maybe<void> Reduce(const void* in, void* out, size_t elem_cnt, DataType dtype,
ReduceType reduce_type, int64_t root, Symbol<ParallelDesc> parallel_desc,
ep::Stream* stream);
Maybe<void> CpuBroadcast(const void* in, void* out, size_t buffer_size, int64_t root,
                         Symbol<ParallelDesc> parallel_desc, const TransportToken& transport_token);
...
@@ -71,9 +71,11 @@ IBVerbsCommNet::~IBVerbsCommNet() {
  for (IBVerbsQP* qp : qp_vec_) {
    if (qp) { delete qp; }
  }
-  CHECK_EQ(ibv::wrapper.ibv_destroy_cq(cq_), 0);
-  CHECK_EQ(ibv::wrapper.ibv_dealloc_pd(pd_), 0);
-  CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0);
+  PCHECK(ibv::wrapper.ibv_destroy_cq(cq_) == 0);
+  PCHECK(ibv::wrapper.ibv_dealloc_pd(pd_) == 0);
+  CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0)
+      << "Error, failed to close the IB device "
+      << ibv::wrapper.ibv_get_device_name(context_->device);
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
@@ -127,20 +129,21 @@ IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT
    CHECK(device != nullptr) << "No IB device match " << user_device;
  }
  context_ = ibv::wrapper.ibv_open_device(device);
-  CHECK(context_);
+  CHECK(context_ != NULL) << "Error, failed to open the IB device "
+                          << ibv::wrapper.ibv_get_device_name(device);
  ibv::wrapper.ibv_free_device_list(device_list);
  pd_ = ibv::wrapper.ibv_alloc_pd(context_);
-  CHECK(pd_);
+  CHECK(pd_) << "Error, ibv_alloc_pd() allocates a Protection Domain (PD) failed";
  ibv_device_attr device_attr{};
-  CHECK_EQ(ibv::wrapper.ibv_query_device(context_, &device_attr), 0);
+  PCHECK(ibv::wrapper.ibv_query_device(context_, &device_attr) == 0);
  cq_ = ibv::wrapper.ibv_create_cq(context_, device_attr.max_cqe, nullptr, nullptr, 0);
-  CHECK(cq_);
+  PCHECK(cq_);
  ibv_port_attr port_attr{};
  const uint8_t port = user_port == 0 ? 1 : user_port;
-  CHECK_EQ(ibv::wrapper.ibv_query_port_wrap(context_, port, &port_attr), 0);
+  PCHECK(ibv::wrapper.ibv_query_port_wrap(context_, port, &port_attr) == 0);
  ibv_gid gid{};
  const int64_t gid_index = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_GID_INDEX", 0);
-  CHECK_EQ(ibv::wrapper.ibv_query_gid(context_, port, gid_index, &gid), 0);
+  PCHECK(ibv::wrapper.ibv_query_gid(context_, port, gid_index, &gid) == 0);
  VLOG(1) << "Using IB device " << device->name << " port " << static_cast<int32_t>(port)
          << " gid index " << gid_index;
  int64_t this_machine_id = GlobalProcessCtx::Rank();
...