Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
......@@ -28,7 +28,7 @@ struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
class ReshapeGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
......@@ -51,7 +51,34 @@ class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("reshape", ReshapeOpExprGrad);
class ReshapeLikeGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_OR_RETURN(!inputs.at(1)->requires_grad())
<< "ReshapeLikeOp's input[1] need not requires_grad.";
ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("reshape", ReshapeGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("reshape_like", ReshapeLikeGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct RMSNormCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool weight_requires_grad = false;
int x_index = -1;
int inv_rms_index = -1;
int weight_index = -1;
};
class RMSNormGrad : public OpExprGradFunction<RMSNormCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> RMSNormGrad::Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
// (x, [weight])
CHECK_GE_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_LE_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
// (y, inv_rms)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
// save x
ctx->x_requires_grad = inputs[0]->requires_grad();
ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);
// save weight
ctx->weight_requires_grad = false;
if (inputs.size() > 1) {
ctx->weight_requires_grad = inputs[1]->requires_grad();
ctx->weight_index = ctx->SaveTensorForBackward(inputs[1]);
}
// save inv_rms
if (ctx->x_requires_grad || ctx->weight_requires_grad) {
ctx->inv_rms_index = ctx->SaveTensorForBackward(outputs[1]);
}
return Maybe<void>::Ok();
}
Maybe<void> RMSNormGrad::Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
// (x, inv_rms) or (x, weight, inv_rms)
const auto& saved_tensors = ctx->SavedTensors();
CHECK_GE_OR_RETURN(saved_tensors.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_LE_OR_RETURN(saved_tensors.size(), 3); // NOLINT(maybe-need-error-msg)
// (dy, inv_rms_diff)
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads[0];
const auto& x = saved_tensors.at(ctx->x_index);
const auto& inv_rms = saved_tensors.at(ctx->inv_rms_index);
// (x_grad, weight_grad)
in_grads->resize(2);
if (ctx->x_requires_grad) {
if (saved_tensors.size() == 3) {
const auto& weight = saved_tensors.at(ctx->weight_index);
in_grads->at(0) = JUST(functional::RMSNormGrad(dy, x, inv_rms, weight, /*param_grad*/ false));
} else {
in_grads->at(0) =
JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ false));
}
}
if (ctx->weight_requires_grad) {
in_grads->at(1) =
JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ true));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("rms_norm", RMSNormGrad);
} // namespace one
} // namespace oneflow
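// Illustrative sketch (assumptions, not this commit's kernel): the forward that produces the
// (y, inv_rms) pair captured above is conventionally
//   inv_rms = 1 / sqrt(mean(x^2) + eps),  y = x * inv_rms * weight,
// which is why the backward only needs x, weight and inv_rms. Names and eps handling below
// are for illustration only.
#include <cmath>
#include <cstddef>
#include <vector>
std::vector<float> RMSNormForwardRef(const std::vector<float>& x, const std::vector<float>& weight,
                                     float eps, float* inv_rms_out) {
  double mean_sq = 0.0;
  for (float v : x) { mean_sq += static_cast<double>(v) * v; }
  mean_sq /= static_cast<double>(x.size());
  const float inv_rms = 1.0f / std::sqrt(static_cast<float>(mean_sq) + eps);
  *inv_rms_out = inv_rms;  // corresponds to the tensor saved as outputs[1] above
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) { y[i] = x[i] * inv_rms * weight[i]; }
  return y;
}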
......@@ -13,15 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
// FloorDiv's derivative function doesn't exist. (author: zhengzekang)
struct ScalarFloorDivCaptureState : public AutoGradCaptureState {};
struct ScalarFloorDivCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
};
class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
public:
......@@ -29,17 +31,20 @@ class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
Maybe<void> Capture(ScalarFloorDivCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarFloorDivCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
UNIMPLEMENTED_THEN_RETURN() << "RuntimeError: derivative for floor_divide is not implemented";
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_floordiv", ScalarFloorDiv);
......
......@@ -55,7 +55,6 @@ class ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {
Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
MutableAttrMap attrs;
in_grads->resize(1);
if (ctx->requires_grad) {
in_grads->at(0) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));
......@@ -64,7 +63,6 @@ class ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {
}
private:
std::shared_ptr<OpExpr> grad_op_;
AttrMap base_attrs_;
};
......@@ -100,7 +98,6 @@ class ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {
Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors()[0];
MutableAttrMap attrs;
in_grads->resize(1);
if (ctx->requires_grad) {
(*in_grads)[0] = JUST(functional::ScalarReversePowGrad(x, out_grads[0], ctx->operand));
......@@ -109,7 +106,6 @@ class ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {
}
private:
std::shared_ptr<OpExpr> grad_op_;
AttrMap base_attrs_;
};
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct ScalarTruncDivCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
};
class ScalarTruncDiv : public OpExprGradFunction<ScalarTruncDivCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ScalarTruncDivCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarTruncDivCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_truncdiv", ScalarTruncDiv);
} // namespace one
} // namespace oneflow
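// Why ScalarFloorDiv and ScalarTruncDiv above can fill the input grad with zeros:
// x -> floor(x / c) and x -> trunc(x / c) are step functions of x, so their derivative is
// 0 almost everywhere (and undefined at the jumps). Element-wise illustration only; the
// name below is not an OneFlow API.
float ScalarFloorDivGradElemRef(float /*x*/, float /*scalar*/) {
  return 0.0f;  // d/dx floor(x / scalar) == 0 wherever the derivative exists
}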
......@@ -98,7 +98,7 @@ class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
if (ctx->requires_grad_ref) {
ctx->value_shape = *(inputs[1]->shape());
if (inputs[1]->is_consistent()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
if (inputs[1]->is_global()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
}
return Maybe<void>::Ok();
}
......@@ -114,8 +114,7 @@ class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
JUST(out_grads[0]->device())));
} else {
const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
zeros =
JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
zeros = JUST(functional::GlobalConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
}
(*in_grads)[0] = JUST(functional::SliceUpdate(out_grads[0], zeros, ctx->start, ctx->stop,
......
......@@ -22,7 +22,8 @@ namespace oneflow {
namespace one {
struct SmoothL1LossCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
float beta = 0.0;
};
......@@ -37,13 +38,13 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {
Maybe<void> Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->requires_grad = inputs.at(0)->requires_grad(); // prediction
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(inputs.at(0)); // prediction
ctx->SaveTensorForBackward(inputs.at(1)); // label
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
ctx->target_requires_grad = inputs.at(1)->requires_grad(); // target
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->beta = JUST(composed_attrs.GetAttr<float>("beta"));
......@@ -52,15 +53,15 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {
Maybe<void> Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
const auto& grad = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));
const auto& prediction = ctx->SavedTensors().at(0);
const auto& label = ctx->SavedTensors().at(1);
in_grads->at(0) =
JUST(functional::SmoothL1LossGrad(out_grads.at(0), prediction, label, ctx->beta));
if (ctx->input_requires_grad) { (*in_grads)[0] = grad; }
if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::Negative(grad)); }
return Maybe<void>::Ok();
}
......
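// Standalone sketch of the element-wise smooth L1 gradient assumed by the change above:
// with d = input - target, d(loss)/d(input) is d / beta for |d| < beta and sign(d) otherwise,
// and d(loss)/d(target) is its negation, which is why the target grad uses
// functional::Negative(grad). Reduction handling is omitted; an illustration, not the kernel.
#include <cmath>
float SmoothL1GradWrtInputRef(float input, float target, float beta) {
  const float d = input - target;
  if (std::fabs(d) < beta) { return d / beta; }
  return d > 0.0f ? 1.0f : -1.0f;
}
float SmoothL1GradWrtTargetRef(float input, float target, float beta) {
  return -SmoothL1GradWrtInputRef(input, target, beta);  // chain rule through d = input - target
}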
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
......@@ -50,10 +51,10 @@ Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureS
const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(outputs.at(0)); // prob
ctx->SaveTensorForBackward(inputs.at(1)); // label
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // prob
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // label
return Maybe<void>::Ok();
}
......@@ -61,15 +62,14 @@ Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCapt
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(1);
const auto& prob = ctx->SavedTensors().at(0);
const auto& label = ctx->SavedTensors().at(1);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
const auto& dy = JUST(VectorAt(out_grads, 1));
const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));
// SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
// require gradient.
in_grads->resize(2);
in_grads->at(0) = JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));
return Maybe<void>::Ok();
}
......
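// Sketch of what SparseSoftmaxCrossEntropyGrad computes per row, assuming the usual fused
// formulation: d(prediction) = dy * (prob - onehot(label)). A plain reference implementation
// for illustration, not the registered kernel.
#include <cstddef>
#include <cstdint>
#include <vector>
std::vector<float> SoftmaxXentGradRowRef(const std::vector<float>& prob, int64_t label, float dy) {
  std::vector<float> dx(prob.size());
  for (std::size_t c = 0; c < prob.size(); ++c) {
    const float onehot = (static_cast<int64_t>(c) == label) ? 1.0f : 0.0f;
    dx[c] = dy * (prob[c] - onehot);
  }
  return dx;
}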
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SparseSoftmaxCrossEntropyMsCaptureState : public AutoGradCaptureState {
int64_t depth = 0;
};
class SparseSoftmaxCrossEntropyMs
: public OpExprGradFunction<SparseSoftmaxCrossEntropyMsCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> SparseSoftmaxCrossEntropyMs::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropyMs::Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // prob
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // label
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropyMs::Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& dy = JUST(VectorAt(out_grads, 1));
const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));
// SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
// require gradient.
in_grads->resize(2);
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::SparseSoftmaxCrossEntropyMsGrad(dy, prob, label, ctx->depth));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy_ms", SparseSoftmaxCrossEntropyMs);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct TruncCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class Trunc : public OpExprGradFunction<TruncCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TruncCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> Trunc::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Trunc::Capture(TruncCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
return Maybe<void>::Ok();
}
Maybe<void> Trunc::Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("trunc", Trunc);
} // namespace one
} // namespace oneflow
......@@ -73,8 +73,8 @@ Maybe<void> Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_g
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
in_grads->at(0) =
JUST(functional::Fold(out_grads.at(0), ctx->data_format, ctx->output_size, ctx->kernel_size,
ctx->dilation_rate, ctx->padding, ctx->strides));
JUST(functional::Fold(out_grads.at(0), ctx->output_size, ctx->kernel_size, ctx->dilation_rate,
ctx->padding, ctx->strides, ctx->data_format));
return Maybe<void>::Ok();
}
......
......@@ -100,7 +100,7 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -112,7 +112,6 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
......@@ -151,7 +150,7 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureSt
ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -163,7 +162,6 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureSt
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad(
......@@ -200,7 +198,7 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState>
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -212,7 +210,6 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState>
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad(
......@@ -247,7 +244,7 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -259,7 +256,6 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
......@@ -298,7 +294,7 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureStat
ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -310,7 +306,6 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureStat
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad(
......@@ -348,7 +343,7 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -360,7 +355,6 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
......@@ -401,7 +395,7 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
if (base_attrs_.find("output_size") != base_attrs_.end()) {
if (composed_attrs.Has("output_size")) {
ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
......@@ -413,7 +407,6 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
MutableAttrMap attrs;
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad(
......@@ -430,4 +423,4 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D);
} // namespace one
} // namespace oneflow
\ No newline at end of file
} // namespace oneflow
......@@ -68,6 +68,9 @@ Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_gra
TensorTuple* in_grads) const {
// TODO(): replace it with a kernel implementation
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
DataType data_type = x->dtype()->data_type();
CHECK_NE_OR_RETURN(data_type, DataType::kBFloat16)
<< Error::RuntimeError() << "Variance op does not support backward for bfloat16 yet!";
size_t correction = ctx->unbiased ? 1 : 0;
size_t elem_cnt = 1;
CHECK_OR_RETURN(ctx->axis.size() > 0)
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct VectorMatrixProductCaptureState : public AutoGradCaptureState {
bool requires_grad_a = false;
bool requires_grad_b = false;
size_t a_index = 0;
size_t b_index = 1;
};
class VectorMatrixProduct : public OpExprGradFunction<VectorMatrixProductCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(VectorMatrixProductCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const VectorMatrixProductCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
protected:
AttrMap base_attrs_;
};
Maybe<void> VectorMatrixProduct::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> VectorMatrixProduct::Capture(VectorMatrixProductCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();
ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();
if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
if (ctx->requires_grad_a) {
ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b
}
if (ctx->requires_grad_b) {
ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a
}
return Maybe<void>::Ok();
}
Maybe<void> VectorMatrixProduct::Apply(const VectorMatrixProductCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. ";
in_grads->resize(2);
if (ctx->requires_grad_a) {
const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::VectorMatrixProductGradA(JUST(VectorAt(out_grads, 0)), input_b));
}
if (ctx->requires_grad_b) {
const auto& input_a = JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->a_index));
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::VectorMatrixProductGradB(JUST(VectorAt(out_grads, 0)), input_a));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("vector_matrix_product", VectorMatrixProduct);
} // namespace one
} // namespace oneflow
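// Sketch of the math behind VectorMatrixProductGradA/GradB, assuming the forward is
// y = a * B with a of shape (n,) and B of shape (n, m):
//   dL/da = B * dy         (shape (n,))
//   dL/dB = outer(a, dy)   (shape (n, m))
// Plain loops for illustration only; shapes and names are assumptions.
#include <cstddef>
#include <vector>
void VectorMatrixProductGradsRef(const std::vector<float>& a,               // (n)
                                 const std::vector<std::vector<float>>& B,  // (n, m)
                                 const std::vector<float>& dy,              // (m)
                                 std::vector<float>* da,                    // (n)
                                 std::vector<std::vector<float>>* dB) {     // (n, m)
  const std::size_t n = a.size(), m = dy.size();
  da->assign(n, 0.0f);
  dB->assign(n, std::vector<float>(m, 0.0f));
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      (*da)[i] += B[i][j] * dy[j];  // row i of B dotted with dy
      (*dB)[i][j] = a[i] * dy[j];   // outer product of a and dy
    }
  }
}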
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstddef>
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
struct BaseActivationGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
};
typedef Maybe<one::Tensor> (*NoParamActivationBwFunc)(const std::shared_ptr<one::Tensor>&,
const std::shared_ptr<one::Tensor>&);
template<NoParamActivationBwFunc BwFunc, NoParamActivationBwFunc BwBwFunc>
class NoParamActivationGradGrad : public OpExprGradFunction<BaseActivationGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BaseActivationGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// dy, x
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(1)->requires_grad();
ctx->grad_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const BaseActivationGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& x = ctx->SavedTensors().at(0);
if (ctx->x_requires_grad) {
const auto& grad = ctx->SavedTensors().at(1);
in_grads->at(1) = JUST(functional::Mul(out_grads.at(0), JUST(BwBwFunc(x, grad))));
}
if (ctx->grad_requires_grad) { in_grads->at(0) = JUST(BwFunc(out_grads.at(0), x)); }
return Maybe<void>::Ok();
}
};
#define INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(op_type_name, op_cls) \
class op_cls##GradGradCls final \
: public NoParamActivationGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> { \
}; \
REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);
// first order backward param: (dy, x)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("mish_grad", Mish)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("gelu_grad", Gelu)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("silu_grad", Silu)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("selu_grad", Selu)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("softsign_grad", SoftSign)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("hardsigmoid_grad", HardSigmoid)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS("hardswish_grad", HardSwish)
#undef INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
struct HardShrinkGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;
bool grad_requires_grad = false;
double lambd = 0.5;
};
class HardShrinkGradGrad : public OpExprGradFunction<HardShrinkGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(HardShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// y, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->y_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const HardShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(functional::HardShrinkGrad(y, out_grads.at(0), ctx->lambd));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct SoftShrinkGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;
bool grad_requires_grad = false;
double alpha = 0.5;
};
class SoftShrinkGradGrad : public OpExprGradFunction<SoftShrinkGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SoftShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// y, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->y_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SoftShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(functional::SoftShrinkGrad(y, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct ReluGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;
bool grad_requires_grad = false;
};
class ReluGradGrad : public OpExprGradFunction<ReluGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ReluGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// dy, y
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->y_requires_grad = inputs.at(1)->requires_grad();
ctx->grad_requires_grad = inputs.at(0)->requires_grad();
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->y_requires_grad) { in_grads->at(1) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
}
return Maybe<void>::Ok();
}
};
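// Element-wise sketch of why ReluGradGrad reuses ReluGrad for the dy slot and zeros for the
// y slot: the first-order backward is relu_grad(dy, y) = dy * (y > 0), which is linear in dy
// with coefficient (y > 0) and piecewise constant in y, so the second-order grads are
// upstream * (y > 0) and 0 almost everywhere. Illustration only; names are not OneFlow APIs.
float ReluGradElemRef(float dy, float y) { return y > 0.0f ? dy : 0.0f; }
float ReluGradGradWrtDyRef(float upstream, float y) { return ReluGradElemRef(upstream, y); }
float ReluGradGradWrtYRef(float /*upstream*/, float /*y*/) { return 0.0f; }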
struct LeakyReluGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
float alpha = 0.01;
};
class LeakyReluGradGrad : public OpExprGradFunction<LeakyReluGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(LeakyReluGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// x, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const LeakyReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct SoftplusGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
double beta = 1.0;
double threshold = 20.0;
};
class SoftplusGradGrad : public OpExprGradFunction<SoftplusGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SoftplusGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// x, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->beta = JUST(composed_attrs.GetAttr<double>("beta"));
ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold"));
ctx->SaveTensorForBackward(inputs.at(0));
if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SoftplusGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& x = ctx->SavedTensors().at(0);
if (ctx->x_requires_grad) {
const auto& grad = ctx->SavedTensors().at(1);
in_grads->at(0) = JUST(functional::Mul(
out_grads.at(0), JUST(functional::SoftplusGradGrad(x, grad, ctx->beta, ctx->threshold))));
}
if (ctx->grad_requires_grad) {
in_grads->at(1) =
JUST(functional::SoftplusGrad(x, out_grads.at(0), ctx->beta, ctx->threshold));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct HardTanhGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;
bool grad_requires_grad = false;
double min_val = -1.0;
double max_val = 1.0;
};
class HardTanhGradGrad : public OpExprGradFunction<HardTanhGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(HardTanhGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// y, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->y_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->min_val = JUST(composed_attrs.GetAttr<double>("min_val"));
ctx->max_val = JUST(composed_attrs.GetAttr<double>("max_val"));
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const HardTanhGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(1) =
JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct EluGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
double alpha = 1.0;
};
class EluGradGrad : public OpExprGradFunction<EluGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(EluGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// x, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& x = ctx->SavedTensors().at(0);
if (ctx->x_requires_grad) {
const auto& grad = ctx->SavedTensors().at(1);
in_grads->at(0) = JUST(
functional::Mul(out_grads.at(0), JUST(functional::EluGradGrad(x, grad, ctx->alpha))));
}
if (ctx->grad_requires_grad) {
in_grads->at(1) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
class CeluGradGrad : public EluGradGrad {
public:
Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& x = ctx->SavedTensors().at(0);
if (ctx->x_requires_grad) {
const auto& grad = ctx->SavedTensors().at(1);
in_grads->at(0) = JUST(
functional::CeluGradGrad(x, JUST(functional::Mul(out_grads.at(0), (grad))), ctx->alpha));
}
if (ctx->grad_requires_grad) {
in_grads->at(1) = JUST(functional::CeluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
};
struct PReluGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;
bool input_requires_grad = false;
bool alpha_requires_grad = false;
size_t grad_index = 0;
size_t input_index = 1;
size_t alpha_index = 2;
};
class PReluGradGrad : public OpExprGradFunction<PReluGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(PReluGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// dy, x, alpha
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs.at(0)->requires_grad(); // grad
ctx->input_requires_grad = inputs.at(1)->requires_grad(); // input
ctx->alpha_requires_grad = inputs.at(2)->requires_grad(); // alpha
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(1));
ctx->alpha_index = ctx->SaveTensorForBackward(inputs.at(2));
ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const PReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(3);
const auto& input = ctx->SavedTensors().at(ctx->input_index);
const auto& alpha = ctx->SavedTensors().at(ctx->alpha_index);
const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
const auto& grad_for_input = out_grads.at(0);
const auto& grad_for_alpha = out_grads.at(1);
const auto& condition = JUST(functional::ScalarLogicalLess(input, Scalar(0.0)));
const auto& zero_grad = JUST(functional::ZerosLike(alpha)); // alpha can broadcast to input
if (ctx->grad_requires_grad) {
auto input_mul_grad = JUST(functional::Mul(alpha, grad_for_input));
auto alpha_mul_grad = JUST(functional::Mul(input, grad_for_alpha));
auto result = JUST(functional::Add(input_mul_grad, alpha_mul_grad, /*alpha=*/Scalar(1.0),
/*inplace*/ false));
in_grads->at(0) = JUST(functional::Where(condition, result, grad_for_input));
}
if (ctx->input_requires_grad) {
auto result = JUST(functional::Mul(grad, grad_for_alpha));
in_grads->at(1) = JUST(functional::Where(condition, result, zero_grad));
}
if (ctx->alpha_requires_grad) {
auto result = JUST(functional::Mul(grad, grad_for_input));
in_grads->at(2) = JUST(functional::Where(condition, result, zero_grad));
}
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
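// Element-wise sketch of the PReluGradGrad math above. The first-order backward is
//   dx     = dy * (x > 0 ? 1 : alpha)
//   dalpha = dy * (x > 0 ? 0 : x)
// so differentiating those two expressions w.r.t. (dy, x, alpha) and applying the upstream
// grads (grad_for_input, grad_for_alpha) gives exactly the three Where() branches used in
// Apply. Broadcasting/reduction of alpha is ignored here; an illustration, not the kernel.
struct PReluGradGradElemRef {
  float d_dy;     // gradient w.r.t. the incoming dy
  float d_x;      // gradient w.r.t. x
  float d_alpha;  // gradient w.r.t. alpha
};
PReluGradGradElemRef PReluGradGradRef(float dy, float x, float alpha, float grad_for_input,
                                      float grad_for_alpha) {
  PReluGradGradElemRef out{};
  if (x < 0.0f) {
    out.d_dy = alpha * grad_for_input + x * grad_for_alpha;
    out.d_x = dy * grad_for_alpha;
    out.d_alpha = dy * grad_for_input;
  } else {
    out.d_dy = grad_for_input;
    out.d_x = 0.0f;
    out.d_alpha = 0.0f;
  }
  return out;
}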
struct ThresholdGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
double threshold = 0.0;
};
class ThresholdGradGrad : public OpExprGradFunction<ThresholdGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ThresholdGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// x, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->grad_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold_val"));
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ThresholdGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
if (ctx->grad_requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(functional::ThresholdGrad(x, out_grads.at(0), ctx->threshold));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("relu_grad", ReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("elu_grad", EluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("celu_grad", CeluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("prelu_grad", PReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink_grad", HardShrinkGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink_grad", SoftShrinkGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu_grad", LeakyReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh_grad", HardTanhGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold_grad", ThresholdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus_grad", SoftplusGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct AdaptiveAvgPoolNDGradGradCaptureState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool grad_requires_grad = false;
std::vector<int64_t> pool_output_size;
};
template<int ndims>
class AdaptiveAvgPoolNdNdGradGrad
: public OpExprGradFunction<AdaptiveAvgPoolNDGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// dy, x
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
if (ctx->grad_requires_grad) {
const auto& grad_shape = *inputs[0]->shape();
if (ndims == 1) {
ctx->pool_output_size = {grad_shape[grad_shape.size() - 1]};
} else if (ndims == 2) {
ctx->pool_output_size = {grad_shape[grad_shape.size() - 2],
grad_shape[grad_shape.size() - 1]};
} else if (ndims == 3) {
ctx->pool_output_size = {grad_shape[grad_shape.size() - 3],
grad_shape[grad_shape.size() - 2],
grad_shape[grad_shape.size() - 1]};
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->grad_requires_grad) {
if (ndims == 1) {
(*in_grads)[0] = JUST(functional::AdaptiveAvgPool1D(out_grads[0], ctx->pool_output_size));
} else if (ndims == 2) {
(*in_grads)[0] = JUST(functional::AdaptiveAvgPool2D(out_grads[0], ctx->pool_output_size));
} else if (ndims == 3) {
(*in_grads)[0] = JUST(functional::AdaptiveAvgPool3D(out_grads[0], ctx->pool_output_size));
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
return Maybe<void>::Ok();
}
};
struct AvgPoolGradGradCaptureState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool grad_requires_grad = false;
std::string data_format;
std::vector<int32_t> padding;
std::vector<int32_t> kernel_size;
std::vector<int32_t> stride;
bool ceil_mode = false;
bool count_include_pad = false;
int32_t divisor_override = 0;
};
class AvgPoolNdGradGrad : public OpExprGradFunction<AvgPoolGradGradCaptureState> {
public:
virtual ~AvgPoolNdGradGrad() = default;
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(AvgPoolGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// dy, x
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>("count_include_pad"));
ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>("divisor_override"));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AvgPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
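// avg_pool_nd_grad is linear in dy: its adjoint is the forward average pooling with the same
// attributes, so the grad w.r.t. dy is that pooling applied to out_grad; the grad w.r.t. x is zero.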
if (ctx->grad_requires_grad) {
int32_t ndims = ctx->kernel_size.size();
const auto pool_op =
(ndims == 1 ? functional::AvgPool1D
: (ndims == 2 ? functional::AvgPool2D
: (ndims == 3 ? functional::AvgPool3D : nullptr)));
CHECK_NOTNULL_OR_RETURN(pool_op); // NOLINT(maybe-need-error-msg)
(*in_grads)[0] =
JUST(pool_op(out_grads[0], ctx->kernel_size, ctx->stride, ctx->padding, ctx->ceil_mode,
ctx->count_include_pad, ctx->divisor_override, ctx->data_format));
}
if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d_grad", AdaptiveAvgPoolNdNdGradGrad<1>);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d_grad", AdaptiveAvgPoolNdNdGradGrad<2>);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d_grad", AdaptiveAvgPoolNdNdGradGrad<3>);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
};
class BinaryCrossEntropyGradGrad
: public OpExprGradFunction<BinaryCrossEntropyGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> BinaryCrossEntropyGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
Maybe<void> BinaryCrossEntropyGradGrad::Capture(BinaryCrossEntropyGradGradCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
// dy, input, target[, weight]
CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ctx->target_requires_grad = inputs[2]->requires_grad();
ctx->has_weight = inputs.size() == 4;
ctx->SaveTensorForBackward(inputs[0]); // grad
ctx->SaveTensorForBackward(inputs[1]); // input
ctx->SaveTensorForBackward(inputs[2]); // target
if (ctx->has_weight) {
ctx->SaveTensorForBackward(inputs[3]); // weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyGradGrad::Apply(const BinaryCrossEntropyGradGradCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
3 + ctx->has_weight); // NOLINT(maybe-need-error-msg)
in_grads->resize(3 + ctx->has_weight);
const auto& grad = ctx->SavedTensors()[0];
const auto& input = ctx->SavedTensors()[1];
const auto& target = ctx->SavedTensors()[2];
// dx = grad * [-target/input + (1-target)/(1-input)]
// grad_for_grad = out_grad * [-target/input + (1-target)/(1-input)]
// grad_for_input = out_grad * grad * [target/(input*input) + (1-target)/((1-input)*(1-input))]
// = out_grad * grad * [(input*input-2*input*target+target)/(input*(1-input))^2]
// grad_for_target = -out_grad * grad * [1/(input*(1-input))] = out_grad * grad / (input*(input-1))
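// (target*(1-input)^2 + (1-target)*input^2 expands to input*input - 2*input*target + target,
//  which is the numerator built in the input branch below.)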
if (ctx->grad_requires_grad) {
const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;
(*in_grads)[0] =
JUST(functional::BinaryCrossEntropyLossGrad(out_grads[0], input, target, weight));
}
if (ctx->input_requires_grad) {
auto one_sub_input = JUST(functional::ScalarSub(1, input, /*alpha=*/1));
auto input_mul_target = JUST(functional::Mul(input, target));
auto numerator =
JUST(functional::sequence_function(functional::Square)
.then(std::bind(functional::Sub, std::placeholders::_1, input_mul_target,
/*alpha=*/2, /*inplace=*/false))
.then([&target](const std::shared_ptr<Tensor>& in) {
return functional::Add(in, target, /*alpha=*/1, /*inplace=*/false);
})
.call(input));
auto res = JUST(functional::sequence_function(functional::Mul)
.then(functional::Square)
.then(std::bind(functional::Div, numerator, std::placeholders::_1))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::Mul, std::placeholders::_1, grad))
.call(input, one_sub_input));
(*in_grads)[1] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;
}
if (ctx->target_requires_grad) {
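// d(dx)/d(target) = -grad/(input*(1-input)) = grad/(input*(input-1)); assuming
// functional::LogGrad(x, dy) computes dy / x, the chain below yields
// out_grad * grad / (input*(input-1)).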
auto input_sub_one = JUST(functional::ScalarAdd(-1, input, /*alpha=*/1));
auto res = JUST(functional::sequence_function(functional::Mul)
.then(std::bind(functional::LogGrad, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::Mul, std::placeholders::_1, grad))
.call(input, input_sub_one));
(*in_grads)[2] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_grad", BinaryCrossEntropyGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
bool has_pos_weight = false;
};
class BinaryCrossEntropyWithLogitsGradGrad
: public OpExprGradFunction<BinaryCrossEntropyWithLogitsGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Capture(
BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
// dy, input, target[, weight][, pos_weight]
CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 5); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ctx->target_requires_grad = inputs[2]->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
ctx->has_weight = inputs.size() == 5 || (inputs.size() == 4 && !ctx->has_pos_weight);
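// with 4 inputs the extra tensor is pos_weight if has_pos_weight is set, otherwise weight;
// with 5 inputs both are present (weight first, then pos_weight).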
ctx->SaveTensorForBackward(inputs[0]); // grad
ctx->SaveTensorForBackward(inputs[1]); // input
ctx->SaveTensorForBackward(inputs[2]); // target
if (inputs.size() == 4) {
ctx->SaveTensorForBackward(inputs[3]); // weight or pos_weight
}
if (inputs.size() == 5) {
ctx->SaveTensorForBackward(inputs[3]); // weight
ctx->SaveTensorForBackward(inputs[4]); // pos_weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Apply(
const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
3 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg)
in_grads->resize(3 + ctx->has_weight + ctx->has_pos_weight);
const auto& grad = ctx->SavedTensors()[0];
const auto& input = ctx->SavedTensors()[1];
const auto& target = ctx->SavedTensors()[2];
const size_t pos_weight_index = ctx->has_weight ? 4 : 3;
const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;
const auto& pos_weight =
ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;
// dx = grad * weight * (-target*(1-input.sigmoid())*pos_weight + input.sigmoid()*(1-target))
// grad_for_input = out_grad * grad * weight * sig * (1-sig) * [pos_weight * target + 1 - target]
// grad_for_target = -out_grad * grad * weight * [pos_weight + sig - pos_weight * sig]
if (ctx->grad_requires_grad) {
(*in_grads)[0] = JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(
out_grads[0], input, target, weight, pos_weight));
}
if (ctx->input_requires_grad) {
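// SigmoidGrad(y, dy) is assumed to compute dy * y * (1 - y), so the chain below builds
// out_grad * grad * sig * (1 - sig) before the optional pos_weight and weight factors.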
auto res = JUST(functional::sequence_function(functional::Sigmoid)
.then(std::bind(functional::SigmoidGrad, std::placeholders::_1, grad))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.call(input));
if (ctx->has_pos_weight) {
res = JUST(functional::sequence_function(functional::Mul)
.then([](const std::shared_ptr<Tensor>& input) {
return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1));
})
.then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,
/*inplace=*/false))
.then(std::bind(functional::Mul, std::placeholders::_1, res))
.call(JUST(pos_weight), target));
}
if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }
(*in_grads)[1] = res;
}
if (ctx->target_requires_grad) {
auto res = JUST(functional::sequence_function(functional::Mul)
.then(functional::Negative)
.call(out_grads[0], grad));
if (ctx->has_pos_weight) {
auto sig = JUST(functional::Sigmoid(input));
auto one_sub_sig = JUST(functional::ScalarSub(1, sig, /*alpha=*/1));
res = JUST(functional::sequence_function(functional::Mul)
.then([&sig](const std::shared_ptr<Tensor>& input) {
return functional::Add(input, sig, /*alpha=*/Scalar(1), /*inplace=*/false);
})
.then(std::bind(functional::Mul, std::placeholders::_1, res))
.call(one_sub_sig, JUST(pos_weight)));
}
if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }
(*in_grads)[2] = res;
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_grad",
BinaryCrossEntropyWithLogitsGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
size_t grad_index = 0;
size_t input_index = 0;
size_t target_index = 0;
};
class BinaryCrossEntropyWithLogitsReduceMeanGradGrad
: public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
};
Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Capture(
BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
// dy, input, target
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ctx->target_requires_grad = inputs[2]->requires_grad();
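// save only what each backward branch needs: grad for the input/target grads,
// input for the grad/input grads, target only for the grad's grad.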
if (ctx->input_requires_grad || ctx->target_requires_grad) {
ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]); // grad
}
if (ctx->input_requires_grad || ctx->grad_requires_grad) {
ctx->input_index = ctx->SaveTensorForBackward(inputs[1]); // input
}
if (ctx->grad_requires_grad) {
ctx->target_index = ctx->SaveTensorForBackward(inputs[2]); // target
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Apply(
const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3);
// n = number of elements of input (dx has the same shape); the fused reduce-mean op folds 1/n into dx:
// dx = grad * (input.sigmoid() - target) / n
// grad_for_grad = reduce_mean(out_grad * (input.sigmoid() - target))
// grad_for_input = out_grad * (grad / n) * sig * (1 - sig)
// grad_for_target = -out_grad * (grad / n)
if (ctx->grad_requires_grad) {
const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
(*in_grads)[0] = JUST(
functional::sequence_function(functional::Sigmoid)
.then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,
/*inplace=*/false))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::ReduceMean, std::placeholders::_1, std::vector<int32_t>{},
/*keepdim=*/false))
.call(input));
}
if (ctx->input_requires_grad) {
const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));
(*in_grads)[1] =
JUST(functional::sequence_function(functional::Sigmoid)
.then(std::bind(functional::SigmoidGrad, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::Mul, std::placeholders::_1, mean_grad))
.call(input));
}
if (ctx->target_requires_grad) {
const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));
(*in_grads)[2] = JUST(functional::sequence_function(functional::Mul)
.then(functional::Negative)
.call(out_grads[0], mean_grad));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean_grad",
BinaryCrossEntropyWithLogitsReduceMeanGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
struct ConvDataGradGradCaptureState : public AutoGradCaptureState {
bool w_requires_grad = false;
bool grad_requires_grad = false;
size_t w_index = 0;
size_t grad_index = 0;
std::string data_format;
std::vector<int32_t> padding_before;
std::vector<int32_t> kernel_size;
std::vector<int32_t> strides;
std::vector<int32_t> dilation_rate;
int32_t groups = 0;
};
class ConvDataGradGrad : public OpExprGradFunction<ConvDataGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ConvDataGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> ConvDataGradGrad::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> ConvDataGradGrad::Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
// input: dy, w, x_like, [add to output]
// output: dx
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->w_requires_grad = inputs.at(1)->requires_grad();
ctx->grad_requires_grad = inputs.at(0)->requires_grad();
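// w is only needed to propagate into dy, and dy only to propagate into w, so each is saved conditionally.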
if (ctx->grad_requires_grad) { ctx->w_index = ctx->SaveTensorForBackward(inputs.at(1)); }
if (ctx->w_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
return Maybe<void>::Ok();
}
Maybe<void> ConvDataGradGrad::Apply(const ConvDataGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(3);
size_t num_spatial_dims = ctx->kernel_size.size();
// first order forward: ConvND
// x * w = y ( * => convolution)
// first order backward:
// x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
// w_grad = x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad
// second order forward (first order backward): ConvDataGrad
// y_grad * w.rot180 = x_grad
// second order backward:
// w_grad_grad = out_grads_x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad
// grad_for_y_grad = out_grads_x * w (x.shape * w.shape -> y.shape) call ConvND
// w_grad_grad
if (ctx->w_requires_grad) {
const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
in_grads->at(1) = JUST(functional::ConvFilterGrad(
grad, out_grads.at(0), num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
// grad_for_y_grad
if (ctx->grad_requires_grad) {
const auto& w = ctx->SavedTensors().at(ctx->w_index);
const int32_t ndims = ctx->kernel_size.size();
const auto conv_op = (ndims == 1 ? functional::Conv1d
: (ndims == 2 ? functional::Conv2d
: (ndims == 3 ? functional::Conv3d : nullptr)));
CHECK_NOTNULL_OR_RETURN(conv_op); // NOLINT(maybe-need-error-msg)
in_grads->at(0) =
JUST(conv_op(out_grads.at(0), w, Optional<Tensor>(), ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
}
return Maybe<void>::Ok();
}
struct ConvFilterGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool grad_requires_grad = false;
size_t x_index = 0;
size_t grad_index = 0;
std::string data_format;
std::vector<int32_t> padding_before;
std::vector<int32_t> kernel_size;
std::vector<int32_t> strides;
std::vector<int32_t> dilation_rate;
int32_t groups = 0;
};
class ConvFilterGradGrad : public OpExprGradFunction<ConvFilterGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ConvFilterGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ConvFilterGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> ConvFilterGradGrad::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> ConvFilterGradGrad::Capture(ConvFilterGradGradCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
// input: dy, x
// output: dw
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(1)->requires_grad();
ctx->grad_requires_grad = inputs.at(0)->requires_grad();
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->x_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
return Maybe<void>::Ok();
}
Maybe<void> ConvFilterGradGrad::Apply(const ConvFilterGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
size_t num_spatial_dims = ctx->kernel_size.size();
// first order forward: ConvND
// x * w = y ( * => convolution)
// first order backward:
// x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
// w_grad = x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad
// second order forward (first order backward): ConvFilterGrad
// x * y_grad = w_grad
// second order backward:
// x_grad_grad = out_grads_w * y_grad.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
// grad_for_y_grad = x * out_grads_w (x.shape * w.shape -> y.shape) call ConvND
// x_grad_grad
if (ctx->x_requires_grad) {
const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
const auto& x = ctx->SavedTensors().at(ctx->x_index);
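// conv_data_grad is assumed to use its x_like input only for the shape of dx, hence the detached x below.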
in_grads->at(1) = JUST(functional::ConvDataGrad(
grad, out_grads.at(0), JUST(x->detach()), num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
// grad_for_y_grad
if (ctx->grad_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const int32_t ndims = ctx->kernel_size.size();
const auto conv_op = (ndims == 1 ? functional::Conv1d
: (ndims == 2 ? functional::Conv2d
: (ndims == 3 ? functional::Conv3d : nullptr)));
CHECK_NOTNULL_OR_RETURN(conv_op); // NOLINT(maybe-need-error-msg)
in_grads->at(0) =
JUST(conv_op(x, out_grads.at(0), Optional<Tensor>(), ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("conv_data_grad", ConvDataGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv_filter_grad", ConvFilterGradGrad);
} // namespace one
} // namespace oneflow