Commit 21d47d0e, authored by yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
namespace {
// Saved state for TF-style pooling backward: SavedTensors() indices recorded by
// Capture plus every forward attribute the backward functor consumes.
struct TFPoolCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
size_t input_index = 0;   // position of the forward input in SavedTensors()
size_t output_index = 0;  // position of the forward output in SavedTensors()
std::string data_format;
std::string padding;
std::vector<int32_t> padding_before;
std::vector<int32_t> padding_after;
std::vector<int32_t> pool_size;
std::vector<int32_t> strides;
bool ceil_mode = false;
};
// Shared backward implementation for the TF-style max/avg pooling ops; the
// concrete subclasses below pick the mode via the two-argument Init overload.
class TFPoolNdGrad : public OpExprGradFunction<TFPoolCaptureState> {
public:
virtual ~TFPoolNdGrad() = default;
using OpExprGradFunction<TFPoolCaptureState>::Init;
// `mode` selects the grad kernel family ("tf_max" or "tf_avg").
Maybe<void> Init(const OpExpr& op, const std::string& mode);
Maybe<void> Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
std::string mode_;    // pooling mode string, fixed at Init time
AttrMap base_attrs_;  // attribute defaults taken from the forward op's proto
};
// Caches the forward op's attribute defaults and remembers which pooling mode
// ("tf_max" / "tf_avg") this gradient function serves.
Maybe<void> TFPoolNdGrad::Init(const OpExpr& op, const std::string& mode) {
  mode_ = mode;
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records everything the backward pass needs: the forward input/output tensors
// and all pooling attributes (runtime attrs override the Init-time defaults).
Maybe<void> TFPoolNdGrad::Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  // The grad functor consumes both the pooled input and its output.
  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
  ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0));
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->data_format = JUST(merged.GetAttr<std::string>("data_format"));
  ctx->padding = JUST(merged.GetAttr<std::string>("padding"));
  ctx->padding_before = JUST(merged.GetAttr<std::vector<int32_t>>("padding_before"));
  ctx->padding_after = JUST(merged.GetAttr<std::vector<int32_t>>("padding_after"));
  ctx->pool_size = JUST(merged.GetAttr<std::vector<int32_t>>("pool_size"));
  ctx->strides = JUST(merged.GetAttr<std::vector<int32_t>>("strides"));
  ctx->ceil_mode = JUST(merged.GetAttr<bool>("ceil_mode"));
  return Maybe<void>::Ok();
}
// Routes the incoming gradient through the matching TFPoolNdGrad functor using
// the tensors and attributes stashed by Capture.
Maybe<void> TFPoolNdGrad::Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& saved = ctx->SavedTensors();
  const auto& pooled_input = saved.at(ctx->input_index);
  const auto& pooled_output = saved.at(ctx->output_index);
  const int32_t ndims = ctx->pool_size.size();  // dimensionality inferred from pool_size
  in_grads->resize(1);
  (*in_grads)[0] = JUST(functional::TFPoolNdGrad(
      pooled_input, pooled_output, out_grads[0], mode_, ndims, ctx->data_format, ctx->padding,
      ctx->padding_before, ctx->padding_after, ctx->pool_size, ctx->strides, ctx->ceil_mode));
  return Maybe<void>::Ok();
}
} // namespace
// 1-D/2-D/3-D TF-style max pooling share the TFPoolNdGrad backward; only the
// mode string differs.
class TFMaxPoolNdGrad final : public TFPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_max"); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_1d", TFMaxPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_2d", TFMaxPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_3d", TFMaxPoolNdGrad);
// 1-D/2-D/3-D TF-style average pooling share the TFPoolNdGrad backward with
// mode "tf_avg".
class TFAvgPoolNdGrad final : public TFPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_avg"); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_1d", TFAvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_2d", TFAvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_3d", TFAvgPoolNdGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
// Captured state for the to_contiguous backward: only the requires-grad flag.
struct ToContiguousCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
// Identity backward for "to_contiguous": the forward only changes memory
// layout, so the gradient passes through unchanged.
class ToContiguous : public OpExprGradFunction<ToContiguousCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ToContiguousCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    // Only the flag is needed; no tensors or attributes are saved.
    ctx->requires_grad = inputs[0]->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ToContiguousCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    // Pass-through: the in-grad is exactly the out-grad.
    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("to_contiguous", ToContiguous);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for the transpose backward pass.
struct TransposeCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> perm;   // forward permutation; inverted in Apply
  bool requires_grad = false;  // BUGFIX: was uninitialized (indeterminate value if read before Capture)
};
// Backward for "transpose": applies the inverse permutation to the out-grad.
class Transpose : public OpExprGradFunction<TransposeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
// Snapshots the forward op's attribute defaults for later composition.
Maybe<void> Transpose::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Saves the forward permutation whenever the input participates in autograd.
Maybe<void> Transpose::Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->perm = JUST(merged.GetAttr<std::vector<int32_t>>("perm"));
  return Maybe<void>::Ok();
}
// Applies the inverse permutation to the out-grad to produce the in-grad.
Maybe<void> Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  // Invert the permutation: forward moved axis i to position perm[i], so the
  // gradient moves it back.
  std::vector<int32_t> grad_perm;
  grad_perm.resize(ctx->perm.size());
  FOR_RANGE(int32_t, i, 0, ctx->perm.size()) { grad_perm.at(ctx->perm.at(i)) = i; }
  // BUGFIX: slot 0 must exist before at(0) is written; every sibling Apply in
  // this file resizes, and an empty tuple would throw std::out_of_range.
  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::Transpose(out_grads.at(0), grad_perm));
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("transpose", Transpose);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for the tril backward pass.
struct TrilCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t diagonal = 0;  // diagonal offset of the lower triangle, from the forward attr
};
// Backward for "tril": masks the out-grad with the same lower triangle.
class Tril : public OpExprGradFunction<TrilCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TrilCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
// Snapshots the forward op's attribute defaults for use in Capture.
Maybe<void> Tril::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records the diagonal offset when the input needs a gradient.
Maybe<void> Tril::Capture(TrilCaptureState* ctx, const TensorTuple& inputs,
                          const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->diagonal = JUST(merged.GetAttr<int64_t>("diagonal"));
  return Maybe<void>::Ok();
}
// The gradient of tril is the out-grad masked by the same lower triangle.
Maybe<void> Tril::Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,
                        TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (ctx->requires_grad) {
    (*in_grads)[0] = JUST(functional::Tril(out_grads.at(0), ctx->diagonal));
  }
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("tril", Tril);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for the triu backward pass.
struct TriuCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;  // BUGFIX: was uninitialized (sibling TrilCaptureState initializes)
  int64_t diagonal = 0;        // BUGFIX: was uninitialized
};
// Backward for "triu": masks the out-grad with the same upper triangle.
class Triu : public OpExprGradFunction<TriuCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TriuCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
// Snapshots the forward op's attribute defaults for use in Capture.
Maybe<void> Triu::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records the diagonal offset when the input needs a gradient.
Maybe<void> Triu::Capture(TriuCaptureState* ctx, const TensorTuple& inputs,
                          const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->diagonal = JUST(merged.GetAttr<int64_t>("diagonal"));
  return Maybe<void>::Ok();
}
// The gradient of triu is the out-grad masked by the same upper triangle.
Maybe<void> Triu::Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
                        TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (ctx->requires_grad) {
    (*in_grads)[0] = JUST(functional::Triu(out_grads.at(0), ctx->diagonal));
  }
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("triu", Triu);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Selects which reduction the templated grad functions below differentiate.
enum class ReduceMode : int32_t {
kMin = 0,
kMax = 1,
};
// Captured state for reduce_{min,max}_device_stage backward.
struct ReduceDeviceCaptureState : public AutoGradCaptureState {
std::vector<int32_t> axis;
bool requires_grad = false;
size_t mask_index = -1;   // SavedTensors() index of the mask output; -1 wraps to SIZE_MAX as "unset"
size_t count_index = -1;  // SavedTensors() index of the count output; -1 wraps to SIZE_MAX as "unset"
};
// Backward for reduce_{min,max}_device_stage. The forward emits (out, mask,
// count); Apply routes out_grads[0] back through the saved mask and count.
template<ReduceMode mode>
class ReduceDevice : public OpExprGradFunction<ReduceDeviceCaptureState> {
public:
// Snapshots the forward op's attribute defaults.
Maybe<void> Init(const OpExpr& op) override {
const auto* op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto());
return Maybe<void>::Ok();
}
// Saves the reduction axes plus the mask and count outputs for Apply.
Maybe<void> Capture(ReduceDeviceCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1)); // mask
ctx->count_index = ctx->SaveTensorForBackward(outputs.at(2)); // count
return Maybe<void>::Ok();
}
// Dispatches to the min/max device-stage grad functor; only input 0 gets a grad.
Maybe<void> Apply(const ReduceDeviceCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 3); // NOLINT(maybe-need-error-msg)
const auto& mask = ctx->SavedTensors().at(ctx->mask_index);
const auto& count = ctx->SavedTensors().at(ctx->count_index);
in_grads->resize(1);
// `mode` is a compile-time template parameter, so this branch is resolved statically.
if (mode == ReduceMode::kMin) {
in_grads->at(0) =
JUST(functional::ReduceMinDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));
} else {
in_grads->at(0) =
JUST(functional::ReduceMaxDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_device_stage", ReduceDevice<ReduceMode::kMin>);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_device_stage", ReduceDevice<ReduceMode::kMax>);
// Captured state for reduce_{min,max}_global_stage backward.
struct ReduceGlobalCaptureState : public AutoGradCaptureState {
std::vector<int32_t> axis;
bool requires_grad = false;
bool keepdims = false;
size_t mask_index = -1;          // SavedTensors() index of the mask output; -1 wraps to SIZE_MAX
size_t device_count_index = -1;  // SavedTensors() index of the device_count input; -1 wraps to SIZE_MAX
};
// Backward for reduce_{min,max}_global_stage. Forward takes (in, device_count)
// and emits (out, mask); only the first input receives a gradient.
template<ReduceMode mode>
class ReduceGlobal : public OpExprGradFunction<ReduceGlobalCaptureState> {
public:
// Snapshots the forward op's attribute defaults.
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
// Saves axes/keepdims plus the mask output and device_count input for Apply.
Maybe<void> Capture(ReduceGlobalCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1)); // mask
ctx->device_count_index = ctx->SaveTensorForBackward(inputs.at(1)); // device_count
return Maybe<void>::Ok();
}
// Dispatches to the min/max global-stage grad functor. in_grads has two slots
// (one per forward input) but only slot 0 is filled; device_count gets no grad.
Maybe<void> Apply(const ReduceGlobalCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& mask = ctx->SavedTensors().at(ctx->mask_index);
const auto& device_count = ctx->SavedTensors().at(ctx->device_count_index);
in_grads->resize(2);
// `mode` is a compile-time template parameter, so this branch is resolved statically.
if (mode == ReduceMode::kMin) {
in_grads->at(0) = JUST(functional::ReduceMinGlobalStageGrad(
out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));
} else {
in_grads->at(0) = JUST(functional::ReduceMaxGlobalStageGrad(
out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_global_stage", ReduceGlobal<ReduceMode::kMin>);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_global_stage", ReduceGlobal<ReduceMode::kMax>);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for the unfold backward pass (inverted via Fold in Apply).
struct UnfoldInterpState : public AutoGradCaptureState {
bool requires_grad = true;  // NOTE(review): defaults to true unlike sibling states; Capture always overwrites it — confirm intent
std::string data_format = "channels_first";
std::vector<int32_t> output_size;  // spatial extents (H, W) of the forward input, recorded in Capture
std::vector<int32_t> kernel_size;
std::vector<int32_t> dilation_rate;
std::vector<int32_t> padding;
std::vector<int32_t> strides;
};
// Backward for "unfold": Fold with the captured geometry inverts the forward.
class Unfold : public OpExprGradFunction<UnfoldInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(UnfoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
// Snapshots the forward op's attribute defaults for use in Capture.
Maybe<void> Unfold::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records the unfold geometry plus the input's spatial extents so the backward
// Fold can reconstruct the original plane.
Maybe<void> Unfold::Capture(UnfoldInterpState* ctx, const TensorTuple& inputs,
                            const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->data_format = JUST(merged.GetAttr<std::string>("data_format"));
  ctx->kernel_size = JUST(merged.GetAttr<std::vector<int32_t>>("kernel_size"));
  ctx->dilation_rate = JUST(merged.GetAttr<std::vector<int32_t>>("dilation_rate"));
  ctx->padding = JUST(merged.GetAttr<std::vector<int32_t>>("padding"));
  ctx->strides = JUST(merged.GetAttr<std::vector<int32_t>>("strides"));
  // Only 4-D input is supported: dims 2 and 3 are the spatial extents.
  const std::shared_ptr<Tensor>& input = inputs.at(0);
  ctx->output_size = {static_cast<int32_t>(input->shape()->At(2)),
                      static_cast<int32_t>(input->shape()->At(3))};
  return Maybe<void>::Ok();
}
// Fold is the exact inverse of unfold, so it serves as the backward pass.
Maybe<void> Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,
                          TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  (*in_grads)[0] =
      JUST(functional::Fold(out_grads.at(0), ctx->data_format, ctx->output_size, ctx->kernel_size,
                            ctx->dilation_rate, ctx->padding, ctx->strides));
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("unfold", Unfold);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for the unfold_tensor backward pass.
struct UnfoldTensorCaptureState : public AutoGradCaptureState {
int32_t dimension = -1;  // axis being unfolded, from the forward attr
int32_t size = -1;       // window size, from the forward attr
int32_t step = -1;       // window stride, from the forward attr
bool requires_grad = false;
};
// Backward for "unfold_tensor": scatters window gradients back onto the source.
class UnfoldTensor : public OpExprGradFunction<UnfoldTensorCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // attribute defaults from the forward op's proto
  // (removed: unused member `std::shared_ptr<OpExpr> grad_op_` — never referenced)
};
// Snapshots the forward op's attribute defaults for use in Capture.
Maybe<void> UnfoldTensor::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records the unfold parameters and the original tensor for the backward pass.
Maybe<void> UnfoldTensor::Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->dimension = JUST(merged.GetAttr<int32_t>("dimension"));
  ctx->size = JUST(merged.GetAttr<int32_t>("size"));
  ctx->step = JUST(merged.GetAttr<int32_t>("step"));
  // The grad functor needs the original tensor (saved at index 0).
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}
// Produces the input gradient via the unfold_tensor grad functor.
Maybe<void> UnfoldTensor::Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& in = ctx->SavedTensors().at(0);
  // BUGFIX: slot 0 must exist before at(0) is written; every sibling Apply in
  // this file resizes, and an empty tuple would throw std::out_of_range.
  in_grads->resize(1);
  in_grads->at(0) =
      JUST(functional::UnfoldTensorGrad(out_grads.at(0), in, ctx->dimension, ctx->size, ctx->step));
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("unfold_tensor", UnfoldTensor);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace one {
// Captured state for expand_dims (unsqueeze) backward.
struct UnsqueezeCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;  // BUGFIX: was uninitialized (indeterminate value if read before Capture)
  Shape shape;                 // eager-mode input shape used to reshape the gradient
};
// Backward for "expand_dims": reshapes the out-grad back to the input shape.
class Unsqueeze : public OpExprGradFunction<UnsqueezeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
// Snapshots the forward op's attribute defaults.
Maybe<void> Unsqueeze::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Remembers how to recover the input shape: the concrete shape in eager mode,
// or the input tensor itself in lazy mode (shapes may be symbolic there).
Maybe<void> Unsqueeze::Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  if (!LazyMode::is_enabled()) {
    ctx->shape = *(inputs.at(0)->shape());
  } else {
    ctx->SaveTensorForBackward(inputs.at(0));
  }
  return Maybe<void>::Ok();
}
// Reshapes the out-grad back to the original input shape (eager) or to match
// the saved input tensor (lazy).
Maybe<void> Unsqueeze::Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (!LazyMode::is_enabled()) {
    (*in_grads)[0] = JUST(functional::Reshape(out_grads.at(0), ctx->shape));
  } else {
    const auto& like = ctx->SavedTensors().at(0);
    (*in_grads)[0] = JUST(functional::ReshapeLike(out_grads.at(0), like));
  }
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("expand_dims", Unsqueeze);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
// Captured state for the generic upsample backward pass.
struct UpsampleCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  // BUGFIX: was `float align_corners;` — uninitialized and the wrong type; the
  // attribute is fetched with GetAttr<bool> in Capture.
  bool align_corners = false;
  std::string data_format;
  std::string interpolation;
};
// Backward for the generic "upsample" op.
class Upsample : public OpExprGradFunction<UpsampleCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // attribute defaults from the forward op's proto
  // (removed: unused member `std::shared_ptr<OpExpr> grad_op_` — never referenced)
};
// Snapshots the forward op's attribute defaults for use in Capture.
Maybe<void> Upsample::Init(const OpExpr& op) {
  const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
  return Maybe<void>::Ok();
}
// Records every upsample attribute plus the original input tensor.
Maybe<void> Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap merged(attrs, base_attrs_);
  ctx->height_scale = JUST(merged.GetAttr<double>("height_scale"));
  ctx->width_scale = JUST(merged.GetAttr<double>("width_scale"));
  ctx->align_corners = JUST(merged.GetAttr<bool>("align_corners"));
  ctx->data_format = JUST(merged.GetAttr<std::string>("data_format"));
  ctx->interpolation = JUST(merged.GetAttr<std::string>("interpolation"));
  // The grad functor needs the original input (saved at index 0).
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}
// Routes the out-grad through the upsample grad functor with the captured attrs.
Maybe<void> Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& saved_input = ctx->SavedTensors().at(0);
  in_grads->resize(1);
  JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleGrad(
      JUST(oneflow::VectorAt(out_grads, 0)), saved_input, ctx->height_scale, ctx->width_scale,
      ctx->align_corners, ctx->data_format, ctx->interpolation));
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);
// Captured state for the upsample_nearest_2d backward pass.
struct UpsampleNearest2DCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
double height_scale = 0.0;
double width_scale = 0.0;
std::vector<int64_t> output_size;  // optional explicit output size; empty when unset
std::string data_format;
};
// Backward for "upsample_nearest_2d".
class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    // BUGFIX: base_attrs_ was never populated (Init returned immediately), so
    // the `base_attrs_.find("output_size")` check in Capture could never
    // succeed and attr defaults were unavailable. Populate it from the forward
    // op, matching every other grad function in this file.
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }
  Maybe<void> Capture(UpsampleNearest2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    // "output_size" is optional on this op; read it only when the op declares it.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }
  Maybe<void> Apply(const UpsampleNearest2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    // (removed: unused local `MutableAttrMap attrs;`)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attribute defaults from the forward op's proto
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_2d", UpsampleNearest2D);
// State captured by the forward pass of upsample_bilinear_2d.
// Fix: align_corners was an uninitialized bool; value-initialize it so the
// struct never carries an indeterminate value.
struct UpsampleBilinear2DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners = false;
  std::vector<int64_t> output_size;
  std::string data_format;
};
// Backward function for the upsample_bilinear_2d op.
class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records scale factors, align_corners, optional output size, data format
  // and the input tensor.
  Maybe<void> Capture(UpsampleBilinear2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleBilinear2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bilinear_2d", UpsampleBilinear2D);
// State captured by the forward pass of upsample_linear_1d.
// Fix: align_corners was an uninitialized bool; value-initialize it.
struct UpsampleLinear1DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double scale_factor = 0.0;
  bool align_corners = false;
  std::vector<int64_t> output_size;
  std::string data_format;
};
// Backward function for the upsample_linear_1d op.
class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records the scale factor, align_corners, optional output size, data format
  // and the input tensor.
  Maybe<void> Capture(UpsampleLinear1DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleLinear1DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->align_corners,
        ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D);
// State captured by the forward pass of upsample_nearest_1d and consumed by
// the backward pass.
struct UpsampleNearest1DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;      // whether the input needs a gradient
  double scale_factor = 0.0;       // forward scale factor
  std::vector<int64_t> output_size;  // explicit output size, if provided
  std::string data_format;         // layout string
};
// Backward function for the upsample_nearest_1d op.
class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records the scale factor, optional output size, data format and the input.
  Maybe<void> Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleNearest1DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
        functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x,
                                          ctx->scale_factor, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_1d", UpsampleNearest1D);
// State captured by the forward pass of upsample_bicubic_2d.
// Fix: align_corners was an uninitialized bool; value-initialize it.
struct UpsampleBicubic2DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners = false;
  std::vector<int64_t> output_size;
  std::string data_format;
};
// Backward function for the upsample_bicubic_2d op.
class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records scale factors, align_corners, optional output size, data format
  // and the input tensor.
  Maybe<void> Capture(UpsampleBicubic2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleBicubic2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D);
// State captured by the forward pass of upsample_nearest_3d and consumed by
// the backward pass.
struct UpsampleNearest3DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;      // whether the input needs a gradient
  double depth_scale = 0.0;        // depth scale factor
  double height_scale = 0.0;       // height scale factor
  double width_scale = 0.0;        // width scale factor
  std::vector<int64_t> output_size;  // explicit output size, if provided
  std::string data_format;         // layout string
};
// Backward function for the upsample_nearest_3d op.
class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records the three scale factors, optional output size, data format and
  // the input tensor.
  Maybe<void> Capture(UpsampleNearest3DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleNearest3DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
        ctx->width_scale, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_3d", UpsampleNearest3D);
// State captured by the forward pass of upsample_trilinear_3d.
// Fix: align_corners was an uninitialized bool; value-initialize it.
struct UpsampleTrilinear3DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double depth_scale = 0.0;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners = false;
  std::vector<int64_t> output_size;
  std::string data_format;
};
// Backward function for the upsample_trilinear_3d op.
class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records the three scale factors, align_corners, optional output size,
  // data format and the input tensor.
  Maybe<void> Capture(UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    // NOTE(review): base_attrs_ is never populated by Init(); this branch looks
    // permanently dead -- confirm against the forward op definition.
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  // Computes the input gradient via the dedicated functional grad op.
  Maybe<void> Apply(const UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
        ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D);
} // namespace one
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// State captured by the forward pass of var and consumed by the backward pass.
// Modernized: in-class default member initializers (Rule of Zero) replace the
// user-declared constructor with an init list and stray trailing semicolon;
// behavior is identical and the style matches the other capture states.
struct VarianceState : public AutoGradCaptureState {
  bool requires_grad = false;    // whether the input needs a gradient
  bool unbiased = true;          // Bessel's correction (divide by N-1)
  bool keepdim = false;          // whether reduced dims were kept in the output
  std::vector<int32_t> axis;     // reduction axes (empty by default)
};
// Backward function for the var op: records the reduction configuration in
// Capture() and reconstructs the gradient analytically in Apply().
class Variance : public OpExprGradFunction<VarianceState> {
 public:
  // Caches the forward op's attributes into base_attrs_.
  Maybe<void> Init(const OpExpr& op) override;
  // Saves reduction axes, keepdim/unbiased flags and the input tensor.
  Maybe<void> Capture(VarianceState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  // Computes d(var)/dx chained with the incoming gradient.
  Maybe<void> Apply(const VarianceState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // default attributes taken from the forward op proto
};
// Caches the forward op's attribute defaults so Capture() can fall back to
// them through a ComposedAttrMap.
Maybe<void> Variance::Init(const OpExpr& op) {
  const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
  return Maybe<void>::Ok();
}
// Records the reduction axes, the keepdim/unbiased flags and the input tensor
// needed by Apply(). Skips all work when no gradient is required.
Maybe<void> Variance::Capture(VarianceState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed(attrs, base_attrs_);
  ctx->unbiased = JUST(composed.GetAttr<bool>("unbiased"));
  ctx->keepdim = JUST(composed.GetAttr<bool>("keepdim"));
  ctx->axis = JUST(composed.GetAttr<std::vector<int32_t>>("dim"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}
// Backward of var: d(var)/dx = 2/(N - correction) * (x - mean(x)), scaled by
// the incoming gradient (reshaped back over the reduced axes when
// keepdim=false).
Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
  // TODO(): replace it using kernel
  // Fix: mirror the guard used by every other Apply in this file -- Capture()
  // saves no tensor when requires_grad is false, so SavedTensors().at(0)
  // below would be out of range.
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
  // Bessel's correction: the unbiased estimator divides by N-1.
  size_t correction = ctx->unbiased ? 1 : 0;
  size_t elem_cnt = 1;
  CHECK_OR_RETURN(ctx->axis.size() > 0)
      << Error::RuntimeError() << "The size of the axis must greater than 0, but got "
      << ctx->axis.size();
  // N = number of elements reduced per output position.
  for (const auto& item : ctx->axis) { elem_cnt *= x->shape()->At(item); }
  std::shared_ptr<Tensor> out_grad = out_grads.at(0);
  if (ctx->keepdim == false) {
    // Re-insert the reduced axes as size-1 dims so out_grad broadcasts with x.
    const std::shared_ptr<const Shape>& out_grad_shape = out_grad->shape();
    DimVector unsqueeze_vector(out_grad_shape->dim_vec());
    for (int i = 0; i < ctx->axis.size(); i++) {
      unsqueeze_vector.insert(unsqueeze_vector.begin() + ctx->axis.at(i), 1);
    }
    Shape unsqueeze_shape(unsqueeze_vector);
    CHECK_EQ_OR_RETURN(unsqueeze_shape.elem_cnt(), out_grad_shape->elem_cnt())
        << Error::RuntimeError()
        << "tensor size mismatch, expected tensor to have the same number of elements, but got "
        << unsqueeze_shape.elem_cnt() << " and " << out_grad_shape->elem_cnt()
        << " elements respectively";
    out_grad = JUST(functional::Reshape(out_grad, unsqueeze_shape));
  }
  in_grads->resize(1);
  // dL/dx = out_grad * 2/(N - correction) * (x - mean(x)).
  in_grads->at(0) = JUST(functional::Mul(
      out_grad,
      JUST(functional::ScalarMul(
          Scalar(2.0 / (elem_cnt - correction)),
          JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)),
                               /*alpha=*/1.0, /*inplace=*/false))))));
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("var", Variance);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Flags recorded by where(cond, x, y) forward for the backward pass.
// Fix: the bools were uninitialized; value-initialize them so Apply() never
// reads indeterminate values.
struct WhereCaptureState : public AutoGradCaptureState {
  bool requires_grad_x = false;
  bool requires_grad_y = false;
};
// Flag recorded by the scalar where variants for the backward pass.
// Fix: the bool was uninitialized; value-initialize it.
struct WhereScalarCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};
// Backward function for where(cond, x, y): routes the output gradient to x
// where cond holds and to y elsewhere.
class Where : public OpExprGradFunction<WhereCaptureState> {
 public:
  // No forward attributes are needed.
  Maybe<void> Init(const OpExpr& op) override;
  // Saves condition, x and y (when either operand needs a gradient).
  Maybe<void> Capture(WhereCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  // Splits out_grads[0] between x and y and reduces to each operand's shape.
  Maybe<void> Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};
// The where backward needs no attributes from the forward op.
Maybe<void> Where::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
// Saves condition, x and y for the backward pass when either operand needs a
// gradient.
Maybe<void> Where::Capture(WhereCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
  // Fix: validate arity before indexing inputs.at(1)/.at(2), consistent with
  // the other gradient functions in this file (where takes cond, x, y).
  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
  ctx->requires_grad_x = inputs.at(1)->requires_grad();
  ctx->requires_grad_y = inputs.at(2)->requires_grad();
  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }
  ctx->SaveTensorForBackward(inputs.at(0));  // condition
  ctx->SaveTensorForBackward(inputs.at(1));  // x
  ctx->SaveTensorForBackward(inputs.at(2));  // y
  return Maybe<void>::Ok();
}
// Backward of where(cond, x, y): the output gradient flows to x where cond is
// true and to y where it is false; each branch gradient is then reduced back
// to the operand's original (possibly broadcast) shape.
Maybe<void> Where::Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
  const bool need_x_grad = ctx->requires_grad_x;
  const bool need_y_grad = ctx->requires_grad_y;
  if (!need_x_grad && !need_y_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& saved = ctx->SavedTensors();
  const std::shared_ptr<oneflow::one::Tensor>& condition = saved.at(0);
  const std::shared_ptr<oneflow::one::Tensor>& x = saved.at(1);
  const std::shared_ptr<oneflow::one::Tensor>& y = saved.at(2);
  const auto zeros = JUST(functional::ZerosLike(x));
  in_grads->resize(3);
  if (need_x_grad) {
    const auto dx_broadcast = JUST(functional::Where(condition, out_grads.at(0), zeros));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(dx_broadcast, x));
  }
  if (need_y_grad) {
    const auto dy_broadcast = JUST(functional::Where(condition, zeros, out_grads.at(0)));
    in_grads->at(2) = JUST(functional::BroadcastReduceSumLike(dy_broadcast, y));
  }
  return Maybe<void>::Ok();
}
// Shared Capture() for the scalar where variants: only the tensor operand at
// inputs[1] can require a gradient; the condition and that operand are saved.
// Apply() is left for the concrete subclasses, so this class stays abstract.
class WhereScalar : public OpExprGradFunction<WhereScalarCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(WhereScalarCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad = inputs.at(1)->requires_grad();
    if (ctx->requires_grad) {
      ctx->SaveTensorForBackward(inputs.at(0));  // condition
      ctx->SaveTensorForBackward(inputs.at(1));  // tensor operand
    }
    return Maybe<void>::Ok();
  }
};
// Backward of where(cond, scalar_x, y): only y can need a gradient, and it
// receives the output gradient where cond is false.
class WhereScalarX : public WhereScalar {
 public:
  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& saved = ctx->SavedTensors();
    const std::shared_ptr<oneflow::one::Tensor>& condition = saved.at(0);
    const std::shared_ptr<oneflow::one::Tensor>& y = saved.at(1);
    in_grads->resize(2);
    const auto zeros = JUST(functional::ZerosLike(y));
    const auto dy_broadcast = JUST(functional::Where(condition, zeros, out_grads.at(0)));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(dy_broadcast, y));
    return Maybe<void>::Ok();
  }
};
// Backward of where(cond, x, scalar_y): only x can need a gradient, and it
// receives the output gradient where cond is true.
class WhereScalarY : public WhereScalar {
 public:
  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& saved = ctx->SavedTensors();
    const std::shared_ptr<oneflow::one::Tensor>& condition = saved.at(0);
    const std::shared_ptr<oneflow::one::Tensor>& x = saved.at(1);
    in_grads->resize(2);
    const auto zeros = JUST(functional::ZerosLike(x));
    const auto dx_broadcast = JUST(functional::Where(condition, out_grads.at(0), zeros));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(dx_broadcast, x));
    return Maybe<void>::Ok();
  }
};
// Registrations for where and its scalar variants.
REGISTER_OP_EXPR_GRAD_FUNCTION("where", Where);
REGISTER_OP_EXPR_GRAD_FUNCTION("where_scalar_x", WhereScalarX);
REGISTER_OP_EXPR_GRAD_FUNCTION("where_scalar_y", WhereScalarY);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/framework/placement_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
namespace oneflow {
namespace {
// Eligibility check for the "asymmetric-broadcast" boxing function: both sides
// must be 1-D all-broadcast sbp, and one placement must contain the other.
// NOTE(review): logical_shape is unused here; presumably the boxing-checker
// signature requires it -- confirm.
Maybe<void> RawCheckAsymmetricBroadcast(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                                        const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  // Only flat (1-D) sbp hierarchies are handled.
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  // Both sides must be broadcast.
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  // One placement must be a superset of the other (grow or shrink).
  CHECK_OR_RETURN(out->placement()->Bigger(*in->placement())
                  || in->placement()->Bigger(*out->placement()));
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}
// Thread-local cached wrapper used at registration time.
static constexpr auto* CheckAsymmetricBroadcast =
    DECORATE(&RawCheckAsymmetricBroadcast, ThreadLocalCachedCopiable);
// Returns the machine id of the first (machine, device) pair of
// `src_parallel_desc` (in sorted order) that is also contained in
// `dst_parallel_desc`; this machine acts as the broadcast root.
// Fixes: the failure message used to emit a doubled comma
// ("...broadcast," << ", placement_a: ..."); the sentinel -1 variables and
// the `machine_and_device_id_inited` flag are replaced by a direct return
// from the loop, which is equivalent and simpler.
Maybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
                                Symbol<ParallelDesc> dst_parallel_desc) {
  for (int64_t mach_id : src_parallel_desc->sorted_machine_ids()) {
    for (int64_t dev_id : src_parallel_desc->sorted_dev_phy_ids(mach_id)) {
      if (dst_parallel_desc->Containing(mach_id, dev_id)) { return mach_id; }
    }
  }
  // Never reached in a correct program; kept as a diagnostic for oneflow bugs.
  return Error::RuntimeError()
         << "Calculate the intersection of placements "
            "failed during execution of asymmetric broadcast"
         << ", placement_a: " << *JUST(PlacementToString(src_parallel_desc))
         << ", placement_b: " << *JUST(PlacementToString(dst_parallel_desc))
         << "! Please submit an issue in `https://github.com/Oneflow-Inc/oneflow/issues` "
            "and we will fix it as soon as possible";
}
// Thread-local cached wrapper: the root for a placement pair is stable.
static constexpr auto* CachedGetBroadcastRoot = DECORATE(&CalBroadcastRoot, ThreadLocalCached);
// Builds an eager_nccl_broadcast op expression over `parallel_desc` with the
// given broadcast root.
Maybe<one::UserOpExpr> EagerNcclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root) {
  return one::OpBuilder("eager_nccl_broadcast", *JUST(UniqueStr("eager_nccl_broadcast")))
      .Input("in")
      .Output("out")
      // The participating ranks are communicated via a serialized parallel conf.
      .Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
      .Attr<int64_t>("root", root)
      .Build();
}
// Thread-local cached wrapper: the op expr only depends on (placement, root).
static constexpr auto* CachedEagerNcclBroadcast = DECORATE(&EagerNcclBroadcast, ThreadLocalCached);
} // namespace
// Boxing function for broadcast -> broadcast between placements where one is a
// superset of the other: when growing the placement, ranks newly covered by
// `out` receive the data via an eager NCCL broadcast; when shrinking, the
// local tensor is simply re-wrapped under the smaller placement.
Maybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tensor,
                                       Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
  const auto& in_placement = in->placement();
  const auto& out_placement = out->placement();
  // The tensor must actually carry the sbp/placement the caller claims.
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in_placement)
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in_placement)) << ")";
  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
  if (out->placement()->Bigger(*in->placement())) {
    // Growing the placement: only ranks that participate in `out` do work.
    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_placement));
    if (out_parallel_id->has_value()) {
      const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_placement));
      if (!in_parallel_id->has_value()) {
        // This rank has no source data; allocate an empty buffer of the right
        // shape/dtype to receive the broadcast into.
        const std::string& device_type = in_placement->device_tag();
        local_tensor =
            JUST(one::functional::Empty(*tensor->shape(), tensor->dtype(),
                                        JUST(Device::New(device_type)), /*pin_memory=*/false));
      }
      // Broadcast from a root that owns the data to this rank's group.
      const auto& broadcast_group = JUST(GetBroadcastGroup(in_placement, out_placement));
      Symbol<ParallelDesc> broadcast_placement_cur_rank =
          JUST(MapAt(*broadcast_group, GlobalProcessCtx::Rank()));
      int64_t root = JUST(CachedGetBroadcastRoot(in_placement, broadcast_placement_cur_rank));
      std::shared_ptr<one::UserOpExpr> op_expr =
          JUST(CachedEagerNcclBroadcast(broadcast_placement_cur_rank, root));
      local_tensor = JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {local_tensor}));
    }
  }
  // Re-wrap the (possibly freshly received) local tensor under `out`.
  return one::functional::LocalToConsistent(local_tensor, out_placement,
                                            *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
                                            tensor->dtype());
}
// Registers this boxing function under the name used by the boxing expression
// tree, guarded by the eligibility check above.
COMMAND(RegisterBoxingFunction("asymmetric-broadcast", CheckAsymmetricBroadcast,
                               &AsymmetricBroadcast));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
#include <functional>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {
class PlacedNdSbp;
// A named strategy that maps a (source, destination) placed-nd-sbp pair to an
// intermediate placed-nd-sbp, used to split one boxing problem into two
// simpler ones.
class BoxingDividor final {
 public:
  using FunctionT =
      std::function<Maybe<Symbol<PlacedNdSbp>>(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out)>;

  BoxingDividor(const std::string& name, const FunctionT& function)
      : name_(name), function_(function) {}
  ~BoxingDividor() = default;

  // Non-copyable and non-movable. Fix: the assignment operators are now
  // deleted as well -- with only the constructors deleted, `a = b;` still
  // compiled via the implicitly-generated (deprecated) copy assignment.
  BoxingDividor(const BoxingDividor&) = delete;
  BoxingDividor(BoxingDividor&&) = delete;
  BoxingDividor& operator=(const BoxingDividor&) = delete;
  BoxingDividor& operator=(BoxingDividor&&) = delete;

  // Human-readable strategy name (used in diagnostics).
  const std::string& name() const { return name_; }

  // Computes the intermediate placed-nd-sbp for the given endpoints.
  Maybe<Symbol<PlacedNdSbp>> operator()(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) const {
    return function_(in, out);
  }

 private:
  std::string name_;
  FunctionT function_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/boxing_dividor_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {
namespace {
// Builds a dividor whose intermediate keeps the input's nd-sbp but rewrites
// the input placement's device type to `device_type`.
Maybe<BoxingDividor> RawReplaceInDeviceType(DeviceType device_type) {
  const auto& divide = [device_type](Symbol<PlacedNdSbp> in,
                                     Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    const auto& retargeted_placement = JUST(ReplaceDeviceType(in->placement(), device_type));
    return PlacedNdSbp::New(in->nd_sbp(), retargeted_placement);
  };
  return std::make_shared<BoxingDividor>("ReplaceInDeviceType", divide);
}
// Builds a dividor whose intermediate keeps the output's nd-sbp but rewrites
// the output placement's device type to `device_type`.
Maybe<BoxingDividor> RawReplaceOutDeviceType(DeviceType device_type) {
  const auto& divide = [device_type](Symbol<PlacedNdSbp> in,
                                     Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    const auto& retargeted_placement = JUST(ReplaceDeviceType(out->placement(), device_type));
    return PlacedNdSbp::New(out->nd_sbp(), retargeted_placement);
  };
  return std::make_shared<BoxingDividor>("ReplaceOutDeviceType", divide);
}
} // namespace
// Public entry points: thread-local cached wrappers over the raw factories.
decltype(ReplaceInDeviceType) ReplaceInDeviceType =
    DECORATE(&RawReplaceInDeviceType, ThreadLocalCached);
decltype(ReplaceOutDeviceType) ReplaceOutDeviceType =
    DECORATE(&RawReplaceOutDeviceType, ThreadLocalCached);
namespace {
// Collapses an nd_sbp whose components are all identical into a 1-d sbp on the
// same devices with the hierarchy cleared (i.e. a flat placement). Fails if
// the nd_sbp mixes different sbp components.
Maybe<Symbol<PlacedNdSbp>> RawFlattenHierarchy(Symbol<PlacedNdSbp> placed_nd_sbp) {
  const auto& nd_sbp = *placed_nd_sbp->nd_sbp();
  CHECK_GE_OR_RETURN(nd_sbp.sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  const auto& head_sbp = nd_sbp.sbp_parallel(0);
  // Flattening is only well-defined when every axis carries the same sbp.
  for (const auto& each_sbp : nd_sbp.sbp_parallel()) {
    CHECK_OR_RETURN(each_sbp == head_sbp)
        << Error::RuntimeError()
        << "Expected all sbps to be on the same in sbp list during flatten sbps list, but find at "
           "least two sbps, "
        << SbpToString(head_sbp) << " and " << SbpToString(each_sbp) << "!";
  }
  std::vector<Symbol<SbpParallel>> single_sbp_list{SymbolOf(head_sbp)};
  const auto& flat_nd_sbp = JUST(GetNdSbp(single_sbp_list));
  // Copy the parallel conf and drop the hierarchy so the placement is 1-d.
  ParallelConf flat_conf(placed_nd_sbp->placement()->parallel_conf());
  flat_conf.clear_hierarchy();
  const auto& flat_placement = SymbolOf(ParallelDesc(flat_conf));
  return JUST(PlacedNdSbp::New(flat_nd_sbp, flat_placement));
}
// Thread-locally cached variant of RawFlattenHierarchy.
static constexpr auto* FlattenHierarchy = DECORATE(&RawFlattenHierarchy, ThreadLocalCached);
// Dividor: the flattened (hierarchy-free) version of the input side.
Maybe<BoxingDividor> RawFlattenInHierarchy() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return FlattenHierarchy(in);
  };
  return std::make_shared<BoxingDividor>("FlattenInHierarchy", divide);
}
// Expands `in`'s first sbp component across every axis of `out`'s hierarchy,
// keeping `out`'s placement.
Maybe<Symbol<PlacedNdSbp>> RawUnflattenHierarchy(Symbol<PlacedNdSbp> in_placed_nd_sbp,
                                                 Symbol<PlacedNdSbp> out_placed_nd_sbp) {
  CHECK_GE_OR_RETURN(in_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  CHECK_GE_OR_RETURN(out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  const auto& head_sbp = in_placed_nd_sbp->nd_sbp()->sbp_parallel(0);
  const int64_t out_ndim = out_placed_nd_sbp->nd_sbp()->sbp_parallel_size();
  NdSbp expanded_nd_sbp;
  for (int64_t axis = 0; axis < out_ndim; ++axis) {
    expanded_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(head_sbp);
  }
  return JUST(PlacedNdSbp::New(SymbolOf(expanded_nd_sbp), out_placed_nd_sbp->placement()));
}
// Thread-locally cached variant of RawUnflattenHierarchy.
static constexpr auto* UnflattenHierarchy = DECORATE(&RawUnflattenHierarchy, ThreadLocalCached);
// Dividor: `in` unflattened onto `out`'s hierarchy.
Maybe<BoxingDividor> RawUnflattenInHierarchy() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return UnflattenHierarchy(in, out);
  };
  return std::make_shared<BoxingDividor>("UnflattenInHierarchy", divide);
}
// Dividor: `out` unflattened onto `in`'s hierarchy (arguments swapped).
Maybe<BoxingDividor> RawUnflattenOutHierarchy() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return UnflattenHierarchy(out, in);
  };
  return std::make_shared<BoxingDividor>("UnflattenOutHierarchy", divide);
}
} // namespace
// Public entry points: thread-locally cached wrappers around the raw builders.
decltype(FlattenInHierarchy) FlattenInHierarchy =
    DECORATE(&RawFlattenInHierarchy, ThreadLocalCached);
decltype(UnflattenInHierarchy) UnflattenInHierarchy =
    DECORATE(&RawUnflattenInHierarchy, ThreadLocalCached);
decltype(UnflattenOutHierarchy) UnflattenOutHierarchy =
    DECORATE(&RawUnflattenOutHierarchy, ThreadLocalCached);
namespace {
// Builds an `ndim`-dimensional nd_sbp whose every component is partial_sum.
Maybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {
  NdSbp nd_sbp;
  for (int64_t axis = 0; axis < ndim; ++axis) {
    nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();
  }
  return SymbolOf(nd_sbp);
}
// Thread-locally cached partial-sum nd_sbp factory.
auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocalCached);
// Keeps the placement but replaces every sbp component with partial_sum.
Maybe<Symbol<PlacedNdSbp>> RawReplaceNdSbpWithPartialSum(Symbol<PlacedNdSbp> placed_nd_sbp) {
  const int64_t ndim = placed_nd_sbp->nd_sbp()->sbp_parallel_size();
  Symbol<NdSbp> partial_sum_nd_sbp = JUST(CachedGetAllPartialSumNdSbp(ndim));
  return JUST(PlacedNdSbp::New(partial_sum_nd_sbp, placed_nd_sbp->placement()));
}
// Thread-locally cached variant.
static constexpr auto* ReplaceNdSbpWithPartialSum =
    DECORATE(&RawReplaceNdSbpWithPartialSum, ThreadLocalCached);
// Dividor: `out`'s placement with an all-partial-sum nd_sbp.
Maybe<BoxingDividor> RawOutPlacementAndPartialSum() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return ReplaceNdSbpWithPartialSum(out);
  };
  return std::make_shared<BoxingDividor>("OutPlacementAndPartialSum", divide);
}
} // namespace
// Public entry point: thread-locally cached wrapper around the raw builder.
decltype(OutPlacementAndPartialSum) OutPlacementAndPartialSum =
    DECORATE(&RawOutPlacementAndPartialSum, ThreadLocalCached);
namespace {
// Builds an `ndim`-dimensional nd_sbp whose every component is broadcast.
Maybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(int64_t ndim) {
  NdSbp nd_sbp;
  for (int64_t axis = 0; axis < ndim; ++axis) {
    nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
  }
  return SymbolOf(nd_sbp);
}
// Thread-locally cached broadcast nd_sbp factory (also used by the
// first-device dividors later in this file).
auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocalCached);
// Keeps the placement but replaces every sbp component with broadcast.
Maybe<Symbol<PlacedNdSbp>> RawReplaceNdSbpWithBroadcast(Symbol<PlacedNdSbp> placed_nd_sbp) {
  const int64_t ndim = placed_nd_sbp->nd_sbp()->sbp_parallel_size();
  Symbol<NdSbp> broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(ndim));
  return JUST(PlacedNdSbp::New(broadcast_nd_sbp, placed_nd_sbp->placement()));
}
// Thread-locally cached variant.
static constexpr auto* ReplaceNdSbpWithBroadcast =
    DECORATE(&RawReplaceNdSbpWithBroadcast, ThreadLocalCached);
// Dividor: `in`'s placement with an all-broadcast nd_sbp.
Maybe<BoxingDividor> RawInPlacementAndBroadcast() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return ReplaceNdSbpWithBroadcast(in);
  };
  return std::make_shared<BoxingDividor>("InPlacementAndBroadcast", divide);
}
// Dividor: `out`'s placement with an all-broadcast nd_sbp.
Maybe<BoxingDividor> RawOutPlacementAndBroadcast() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return ReplaceNdSbpWithBroadcast(out);
  };
  return std::make_shared<BoxingDividor>("OutPlacementAndBroadcast", divide);
}
} // namespace
// Public entry points: thread-locally cached wrappers around the raw builders.
decltype(InPlacementAndBroadcast) InPlacementAndBroadcast =
    DECORATE(&RawInPlacementAndBroadcast, ThreadLocalCached);
decltype(OutPlacementAndBroadcast) OutPlacementAndBroadcast =
    DECORATE(&RawOutPlacementAndBroadcast, ThreadLocalCached);
namespace {
// Builds a 1-d nd_sbp that splits along `axis`.
Maybe<Symbol<NdSbp>> GetSplitNdSbp(int64_t axis) {
  NdSbp nd_sbp;
  nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);
  return SymbolOf(nd_sbp);
}
// Thread-locally cached variant.
auto* CachedGetSplitNdSbp = DECORATE(&GetSplitNdSbp, ThreadLocalCached);
// Dividor: `in`'s placement with a 1-d split(axis) nd_sbp.
Maybe<BoxingDividor> RawInPlacementAndSplit(int64_t axis) {
  auto divide = [axis](Symbol<PlacedNdSbp> in,
                       Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));
    return PlacedNdSbp::New(split_nd_sbp, in->placement());
  };
  return std::make_shared<BoxingDividor>("InPlacementAndSplit", divide);
}
// Dividor: `out`'s placement with a 1-d split(axis) nd_sbp.
Maybe<BoxingDividor> RawOutPlacementAndSplit(int64_t axis) {
  auto divide = [axis](Symbol<PlacedNdSbp> in,
                       Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));
    return PlacedNdSbp::New(split_nd_sbp, out->placement());
  };
  return std::make_shared<BoxingDividor>("OutPlacementAndSplit", divide);
}
} // namespace
// Public entry points: thread-locally cached wrappers around the raw builders.
decltype(InPlacementAndSplit) InPlacementAndSplit =
    DECORATE(&RawInPlacementAndSplit, ThreadLocalCached);
decltype(OutPlacementAndSplit) OutPlacementAndSplit =
    DECORATE(&RawOutPlacementAndSplit, ThreadLocalCached);
namespace {
// Builds a placement containing only the first device (parallel id 0) of
// `placement`, keeping its device tag and the number of hierarchy axes (each
// collapsed to size 1). NOTE(review): "Fisrt" is a typo for "First"; kept
// because this file-local name is referenced by the dividors below.
Maybe<Symbol<ParallelDesc>> GetFisrtDeviceOfPlacement(Symbol<ParallelDesc> placement) {
  ParallelConf parallel_conf;
  int64_t machine_id = JUST(placement->MachineId4ParallelId(0));
  int64_t device_id = JUST(placement->DeviceId4ParallelId(0));
  parallel_conf.set_device_tag(placement->device_tag());
  // Device name format: "@<machine_id>:<device_id>".
  parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":"
                                + std::to_string(device_id));
  // Preserve the rank of the hierarchy with every axis collapsed to 1.
  for (int64_t i = 0; i < placement->hierarchy()->NumAxes(); ++i) {
    parallel_conf.mutable_hierarchy()->add_dim(1);
  }
  std::shared_ptr<ParallelDesc> parallel_desc;
  // Intern the conf as a ParallelDesc symbol via the instructions builder.
  JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
    parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
    return Maybe<void>::Ok();
  }));
  return SymbolOf(*parallel_desc);
}
// Dividor: all-broadcast nd_sbp on the first device of `in`'s placement.
Maybe<BoxingDividor> RawInFirstDeviceAndAllBroadcast() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    const int64_t ndim = in->nd_sbp()->sbp_parallel_size();
    return PlacedNdSbp::New(JUST(CachedGetAllBroadcastNdSbp(ndim)),
                            JUST(GetFisrtDeviceOfPlacement(in->placement())));
  };
  return std::make_shared<BoxingDividor>("InFirstDeviceAndAllBroadcast", divide);
}
// Dividor: all-broadcast nd_sbp on the first device of `out`'s placement.
Maybe<BoxingDividor> RawOutFirstDeviceAndAllBroadcast() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    const int64_t ndim = out->nd_sbp()->sbp_parallel_size();
    return PlacedNdSbp::New(JUST(CachedGetAllBroadcastNdSbp(ndim)),
                            JUST(GetFisrtDeviceOfPlacement(out->placement())));
  };
  return std::make_shared<BoxingDividor>("OutFirstDeviceAndAllBroadcast", divide);
}
} // namespace
// Public entry points: thread-locally cached wrappers around the raw builders.
decltype(InFirstDeviceAndAllBroadcast) InFirstDeviceAndAllBroadcast =
    DECORATE(&RawInFirstDeviceAndAllBroadcast, ThreadLocalCached);
decltype(OutFirstDeviceAndAllBroadcast) OutFirstDeviceAndAllBroadcast =
    DECORATE(&RawOutFirstDeviceAndAllBroadcast, ThreadLocalCached);
namespace {
// Repeats the first sbp component across every axis of the nd_sbp, keeping
// the placement.
Maybe<Symbol<PlacedNdSbp>> RawPlacementAndRepeatFirstSbp(Symbol<PlacedNdSbp> placed_nd_sbp) {
  const auto& head_sbp = placed_nd_sbp->nd_sbp()->sbp_parallel(0);
  const int64_t ndim = placed_nd_sbp->nd_sbp()->sbp_parallel_size();
  NdSbp repeated_nd_sbp;
  for (int64_t axis = 0; axis < ndim; ++axis) {
    repeated_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(head_sbp);
  }
  return JUST(PlacedNdSbp::New(SymbolOf(repeated_nd_sbp), placed_nd_sbp->placement()));
}
// Thread-locally cached variant.
static constexpr auto* PlacementAndRepeatFirstSbp =
    DECORATE(&RawPlacementAndRepeatFirstSbp, ThreadLocalCached);
// Dividor: `in`'s placement with its first sbp repeated on all axes.
Maybe<BoxingDividor> RawInPlacementAndRepeatFirstSbp() {
  auto divide = [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
    return PlacementAndRepeatFirstSbp(in);
  };
  return std::make_shared<BoxingDividor>("InPlacementAndRepeatFirstSbp", divide);
}
} // namespace
// Public entry point: thread-locally cached wrapper around the raw builder.
decltype(InPlacementAndRepeatFirstSbp) InPlacementAndRepeatFirstSbp =
    DECORATE(&RawInPlacementAndRepeatFirstSbp, ThreadLocalCached);
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/boxing/boxing_dividor.h"
namespace oneflow {
// Thread-locally cached BoxingDividor factories (defined in
// boxing_dividor_util.cpp). Each dividor names an intermediate PlacedNdSbp
// used to split one boxing transformation into two simpler steps.
extern Maybe<BoxingDividor> (*ReplaceInDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*ReplaceOutDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*FlattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenOutHierarchy)();
extern Maybe<BoxingDividor> (*OutPlacementAndPartialSum)();
extern Maybe<BoxingDividor> (*InPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*OutPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*OutPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*InFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*OutFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndRepeatFirstSbp)();
} // namespace oneflow
#endif // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/boxing/boxing_interpreter_status.h"
namespace oneflow {
namespace {
// Builds a single-step status: one boxing name, no intermediate placed_nd_sbp.
Maybe<BoxingInterpreterStatus> RawMakeBoxingInterpreterStatus(const std::string& boxing_name,
                                                              const Shape& logical_shape,
                                                              Symbol<PlacedNdSbp> in,
                                                              Symbol<PlacedNdSbp> out) {
  std::vector<std::string> boxing_names;
  boxing_names.emplace_back(boxing_name);
  return BoxingInterpreterStatus(SymbolOf(boxing_names), logical_shape, in, out);
}
// Composes two boxing statuses end-to-end: lhs must end exactly where rhs
// starts (same placed_nd_sbp) and both must describe the same logical shape.
// The result keeps lhs's source, rhs's destination, and records the junction
// point as an additional intermediate placed_nd_sbp.
Maybe<BoxingInterpreterStatus> RawMakeComposedBoxingInterpreterStatus(
    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,
    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status) {
  CHECK_OR_RETURN(lhs_status->dst_placed_nd_sbp()
                  == rhs_status->src_placed_nd_sbp())  // always true
      << Error::RuntimeError()
      << "Intermediate placed_nd_sbp must be equal when compose boxing interpreter status"
      << ". lhs_status.dst_nd_sbp: " << NdSbpToString(lhs_status->dst_placed_nd_sbp()->nd_sbp())
      // BUGFIX: the labels below previously said "rhs_status.dst_*" while
      // printing rhs's *source* nd_sbp/placement, making the error misleading.
      << ", rhs_status.src_nd_sbp: " << NdSbpToString(rhs_status->src_placed_nd_sbp()->nd_sbp())
      << ", lhs_status.dst_placement: "
      << *JUST(PlacementToString(lhs_status->dst_placed_nd_sbp()->placement()))
      << ", rhs_status.src_placement: "
      << *JUST(PlacementToString(rhs_status->src_placed_nd_sbp()->placement()));
  CHECK_OR_RETURN(lhs_status->logical_shape() == rhs_status->logical_shape())  // always true
      << Error::RuntimeError()
      << "Logical_shape must be equal when compose boxing interpreter status"
      << ". lhs_status.logical_shape: " << (lhs_status->logical_shape().ToString())
      << ". rhs_status.logical_shape: " << (rhs_status->logical_shape().ToString());
  // Concatenate the boxing-name routes of both halves.
  std::vector<std::string> sorted_boxing_names(*lhs_status->sorted_boxing_names());
  sorted_boxing_names.insert(sorted_boxing_names.end(), rhs_status->sorted_boxing_names()->begin(),
                             rhs_status->sorted_boxing_names()->end());
  // Intermediates: lhs's middles, then the junction, then rhs's middles.
  std::vector<Symbol<PlacedNdSbp>> mid_placed_nd_sbp(*lhs_status->mid_placed_nd_sbp());
  mid_placed_nd_sbp.emplace_back(lhs_status->dst_placed_nd_sbp());
  mid_placed_nd_sbp.insert(mid_placed_nd_sbp.end(), rhs_status->mid_placed_nd_sbp()->begin(),
                           rhs_status->mid_placed_nd_sbp()->end());
  BoxingInterpreterStatus status(sorted_boxing_names, lhs_status->logical_shape(),
                                 lhs_status->src_placed_nd_sbp(), SymbolOf(mid_placed_nd_sbp),
                                 rhs_status->dst_placed_nd_sbp());
  return status;
}
} // namespace
// Public entry points declared in boxing_interpreter_status.h; results are
// cached per thread and per argument tuple.
decltype(MakeBoxingInterpreterStatus) MakeBoxingInterpreterStatus =
    DECORATE(&RawMakeBoxingInterpreterStatus, ThreadLocalCachedCopiable);
decltype(MakeComposedBoxingInterpreterStatus) MakeComposedBoxingInterpreterStatus =
    DECORATE(&RawMakeComposedBoxingInterpreterStatus, ThreadLocalCachedCopiable);
namespace {
// Renders "src -> mid... -> dst" using the nd_sbp of each hop.
Maybe<std::string> RawGetNdSbpRouting(Symbol<PlacedNdSbp> src_placed_nd_sbp,
                                      Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
                                      Symbol<PlacedNdSbp> dst_placed_nd_sbp) {
  std::ostringstream route;
  route << NdSbpToString(src_placed_nd_sbp->nd_sbp());
  for (const auto& hop : *mid_placed_nd_sbp) { route << " -> " << NdSbpToString(hop->nd_sbp()); }
  route << " -> " << NdSbpToString(dst_placed_nd_sbp->nd_sbp());
  return route.str();
}
// Renders "src -> mid... -> dst" using the placement of each hop.
Maybe<std::string> RawGetPlacementRouting(
    Symbol<PlacedNdSbp> src_placed_nd_sbp,
    Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
    Symbol<PlacedNdSbp> dst_placed_nd_sbp) {
  std::ostringstream route;
  route << *JUST(PlacementToString(src_placed_nd_sbp->placement()));
  for (const auto& hop : *mid_placed_nd_sbp) {
    route << " -> " << *JUST(PlacementToString(hop->placement()));
  }
  route << " -> " << *JUST(PlacementToString(dst_placed_nd_sbp->placement()));
  return route.str();
}
// Joins the boxing names with " -> "; the list must be non-empty.
Maybe<std::string> RawGetBoxingDesc(Symbol<std::vector<std::string>> sorted_boxing_names) {
  CHECK_OR_RETURN(!sorted_boxing_names->empty())  // always true
      << Error::RuntimeError() << "boxing_names of eager boxing status can't be empty!";
  std::ostringstream desc;
  bool is_first = true;
  for (const auto& name : *sorted_boxing_names) {
    if (!is_first) { desc << " -> "; }
    desc << name;
    is_first = false;
  }
  return desc.str();
}
// Thread-locally cached formatters; caching also keeps the returned strings
// alive so the accessors below can hand out const references.
static constexpr auto* GetNdSbpRouting = DECORATE(&RawGetNdSbpRouting, ThreadLocalCached);
static constexpr auto* GetPlacementRouting = DECORATE(&RawGetPlacementRouting, ThreadLocalCached);
static constexpr auto* GetBoxingDesc = DECORATE(&RawGetBoxingDesc, ThreadLocalCached);
} // namespace
// Human-readable " -> "-joined boxing-name route. The reference stays valid
// because the ThreadLocalCached decorator owns the formatted string.
const std::string& BoxingInterpreterStatus::boxing_routing() const {
  return *CHECK_JUST(GetBoxingDesc(sorted_boxing_names_));
}
// Human-readable nd_sbp route: src -> mids... -> dst.
const std::string& BoxingInterpreterStatus::nd_sbp_routing() const {
  return *CHECK_JUST(GetNdSbpRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));
}
// Human-readable placement route: src -> mids... -> dst.
const std::string& BoxingInterpreterStatus::placement_routing() const {
  return *CHECK_JUST(
      GetPlacementRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
#define ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/common/shape.h"
namespace oneflow {
class BoxingInterpreterStatus;
// Factories (defined in boxing_interpreter_status.cpp, thread-locally cached).
// MakeBoxingInterpreterStatus: a single-step status for one boxing name.
extern Maybe<BoxingInterpreterStatus> (*MakeBoxingInterpreterStatus)(const std::string& boxing_name,
                                                                     const Shape& logical_shape,
                                                                     Symbol<PlacedNdSbp> in,
                                                                     Symbol<PlacedNdSbp> out);
// MakeComposedBoxingInterpreterStatus: chains two statuses that meet at the
// same intermediate placed_nd_sbp.
extern Maybe<BoxingInterpreterStatus> (*MakeComposedBoxingInterpreterStatus)(
    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,
    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status);
// Immutable record of how one eager boxing was (or will be) performed: the
// ordered list of primitive boxing names, the logical tensor shape, and the
// source / intermediate / destination placed_nd_sbps of the route.
class BoxingInterpreterStatus final {
 public:
  // Full constructor with explicit intermediate placed_nd_sbps.
  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,
                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,
                          Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)
      : sorted_boxing_names_(sorted_boxing_names),
        logical_shape_(logical_shape),
        src_placed_nd_sbp_(src_placed_nd_sbp),
        mid_placed_nd_sbp_(mid_placed_nd_sbp),
        dst_placed_nd_sbp_(dst_placed_nd_sbp) {}
  // Single-step constructor: no intermediate placed_nd_sbps.
  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,
                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,
                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)
      : BoxingInterpreterStatus(sorted_boxing_names, logical_shape, src_placed_nd_sbp,
                                SymbolOf(std::vector<Symbol<PlacedNdSbp>>()), dst_placed_nd_sbp) {}
  ~BoxingInterpreterStatus() = default;
  // NOTE(review): logical_shape_ is deliberately (?) excluded here, matching
  // the std::hash specialization below — confirm this is intended.
  bool operator==(const BoxingInterpreterStatus& other) const {
    return this->sorted_boxing_names_ == other.sorted_boxing_names_
           && this->src_placed_nd_sbp_ == other.src_placed_nd_sbp_
           && this->mid_placed_nd_sbp_ == other.mid_placed_nd_sbp_
           && this->dst_placed_nd_sbp_ == other.dst_placed_nd_sbp_;
  }
  // Getters
  Symbol<std::vector<std::string>> sorted_boxing_names() const { return sorted_boxing_names_; }
  const Shape& logical_shape() const { return logical_shape_; }
  Symbol<PlacedNdSbp> src_placed_nd_sbp() const { return src_placed_nd_sbp_; }
  Symbol<PlacedNdSbp> dst_placed_nd_sbp() const { return dst_placed_nd_sbp_; }
  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp() const { return mid_placed_nd_sbp_; }
  // Formatted routes (implemented in boxing_interpreter_status.cpp).
  const std::string& boxing_routing() const;
  const std::string& nd_sbp_routing() const;
  const std::string& placement_routing() const;
 private:
  Symbol<std::vector<std::string>> sorted_boxing_names_;
  const Shape logical_shape_;
  Symbol<PlacedNdSbp> src_placed_nd_sbp_;
  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp_;
  Symbol<PlacedNdSbp> dst_placed_nd_sbp_;
};
} // namespace oneflow
namespace std {
template<>
struct hash<oneflow::BoxingInterpreterStatus> {
  // Hashes the same fields operator== compares (logical_shape is excluded in
  // both, keeping hash/equality consistent). XOR-combining is weak but
  // sufficient here; changing it would change observable bucket distribution.
  size_t operator()(const oneflow::BoxingInterpreterStatus& status) const {
    size_t ret = 0;
    for (const auto& boxing_name : *status.sorted_boxing_names()) {
      ret ^= std::hash<string>()(boxing_name);
    }
    const auto& placed_nd_sbp_hash = std::hash<oneflow::PlacedNdSbp>();
    ret ^= placed_nd_sbp_hash(*status.src_placed_nd_sbp());
    for (const auto& mid_placed_nd_sbp : *status.mid_placed_nd_sbp()) {
      ret ^= placed_nd_sbp_hash(*mid_placed_nd_sbp);
    }
    ret ^= placed_nd_sbp_hash(*status.dst_placed_nd_sbp());
    return hash<size_t>()(ret);
  }
};
} // namespace std
#endif // ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace {
// ccl p->b applies when: both sides are 1-d, in is all partial_sum, out is all
// broadcast, placements match, and the device is CPU or CUDA.
Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_OR_RETURN(in->placement() == out->placement());
  const DeviceType device_type = in->placement()->device_type();
  CHECK_OR_RETURN(device_type == DeviceType::kCPU || device_type == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}
static constexpr auto* CheckCclP2B = DECORATE(&RawCheckCclP2B, ThreadLocalCachedCopiable);
// ccl p->s applies when: both sides are 1-d, in is all partial_sum, out is
// split(0), dim 0 divides evenly across ranks, placements match, CPU/CUDA.
Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllSplit(*out->nd_sbp(), 0));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  const DeviceType device_type = in->placement()->device_type();
  CHECK_OR_RETURN(device_type == DeviceType::kCPU || device_type == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}
static constexpr auto* CheckCclP2S = DECORATE(&RawCheckCclP2S, ThreadLocalCachedCopiable);
// ccl s->b applies when: both sides are 1-d, in is split(0), out is all
// broadcast, dim 0 divides evenly across ranks, placements match, CPU/CUDA.
Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllSplit(*in->nd_sbp(), 0));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  const DeviceType device_type = in->placement()->device_type();
  CHECK_OR_RETURN(device_type == DeviceType::kCPU || device_type == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}
static constexpr auto* CheckCclS2B = DECORATE(&RawCheckCclS2B, ThreadLocalCachedCopiable);
// ccl s->s applies when: both sides are 1-d split along *different* axes, both
// axes are in range and divide evenly across ranks, placements match, CPU/CUDA.
Maybe<void> RawCheckCclS2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  const auto& in_sbp = in->nd_sbp()->sbp_parallel(0);
  const auto& out_sbp = out->nd_sbp()->sbp_parallel(0);
  CHECK_OR_RETURN(in_sbp.has_split_parallel());
  CHECK_OR_RETURN(out_sbp.has_split_parallel());
  CHECK_NE_OR_RETURN(in_sbp.split_parallel().axis(), out_sbp.split_parallel().axis());
  const int64_t in_split_axis = in_sbp.split_parallel().axis();
  const int64_t out_split_axis = out_sbp.split_parallel().axis();
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), in_split_axis);
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), out_split_axis);
  const int64_t parallel_num = in->placement()->parallel_num();
  CHECK_OR_RETURN(logical_shape.At(in_split_axis) % parallel_num == 0);
  CHECK_OR_RETURN(logical_shape.At(out_split_axis) % parallel_num == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  const DeviceType device_type = in->placement()->device_type();
  CHECK_OR_RETURN(device_type == DeviceType::kCPU || device_type == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}
static constexpr auto* CheckCclS2S = DECORATE(&RawCheckCclS2S, ThreadLocalCachedCopiable);
} // namespace
// Boxing kernel: partial_sum -> broadcast via collective all-reduce.
Maybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  // Validate that the tensor really carries the sbp/placement `in` claims.
  const auto& actual_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(actual_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(actual_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& actual_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(actual_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(actual_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentAllReduce(tensor));
}
// Boxing kernel: partial_sum -> split(0) via collective reduce-scatter (sum).
Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  // Validate that the tensor really carries the sbp/placement `in` claims.
  const auto& actual_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(actual_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(actual_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& actual_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(actual_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(actual_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentReduceScatter(tensor, "sum"));
}
// Boxing kernel: split(0) -> broadcast via collective all-gather.
Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  // Validate that the tensor really carries the sbp/placement `in` claims.
  const auto& actual_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(actual_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(actual_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& actual_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(actual_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(actual_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentAllGather(tensor));
}
// Boxing kernel: split(i) -> split(j) via collective all-to-all style S2S.
Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  // Validate that the tensor really carries the sbp/placement `in` claims.
  const auto& actual_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(actual_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(actual_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& actual_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(actual_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(actual_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
}
// Register the four ccl boxing kernels with their applicability checks.
COMMAND(RegisterBoxingFunction("ccl-p-to-b", CheckCclP2B, &CclP2B));
COMMAND(RegisterBoxingFunction("ccl-p-to-s", CheckCclP2S, &CclP2S));
COMMAND(RegisterBoxingFunction("ccl-s-to-b", CheckCclS2B, &CclS2B));
COMMAND(RegisterBoxingFunction("ccl-s-to-s", CheckCclS2S, &CclS2S));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {
namespace {
// True when the two placements are identical once `rhs` is retargeted onto
// `lhs`'s device type (i.e. they differ at most in device type).
Maybe<bool> IgnoringDeviceTypeEqual(Symbol<ParallelDesc> lhs, Symbol<ParallelDesc> rhs) {
  const auto& rhs_on_lhs_device = JUST(ReplaceDeviceType(rhs, lhs->device_type()));
  return lhs == rhs_on_lhs_device;
}
} // namespace
// NOLINTBEGIN(maybe-need-error-msg)
// copy-h2d applies when placements match modulo device type, the source is on
// CPU, the destination is not, and the nd_sbp is unchanged.
Maybe<void> CheckCopyH2D(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  const bool same_modulo_device = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(same_modulo_device);
  CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_NE_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}
// copy-d2h applies when placements match modulo device type, the source is not
// on CPU, the destination is, and the nd_sbp is unchanged.
Maybe<void> CheckCopyD2H(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  const bool same_modulo_device = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(same_modulo_device);
  CHECK_NE_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_EQ_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)
// Moves a consistent tensor between host and device by re-interpreting each
// rank's local tensor under the output placement. Shared by copy-h2d and
// copy-d2h (the checks above guarantee only the device type changes).
Maybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor,
                                      Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));
  if (!out_parallel_id->has_value()) {
    // This rank does not participate in the output placement: substitute an
    // uninitialized placeholder with the physical shape of parallel id 0.
    const std::string& device_type = tensor_placement->device_tag();
    local_tensor = JUST(one::functional::Empty(
        *JUST(GetPhysicalShape(*tensor->shape(), *tensor_nd_sbp, *tensor_placement, 0)),
        tensor->dtype(), JUST(Device::New(device_type)), /*pin_memory=*/false));
  }
  // Reassemble the local tensors into a consistent tensor on `out`.
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
                                                 *tensor->shape(), tensor->dtype()));
}
// Register host<->device copy boxing; both directions share one kernel.
COMMAND(RegisterBoxingFunction("copy-h2d", &CheckCopyH2D, &CopyBoxingFunction));
COMMAND(RegisterBoxingFunction("copy-d2h", &CheckCopyD2H, &CopyBoxingFunction));
} // namespace oneflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment