Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/autograd/autograd_mode.h"
namespace oneflow {
namespace autograd {
namespace {
bool* GetThreadLocalGradMode() {
static thread_local bool g_grad_mode = true;
return &g_grad_mode;
}
} // namespace
bool GradMode::is_enabled() { return *GetThreadLocalGradMode(); }
void GradMode::set_enabled(bool enabled) { *GetThreadLocalGradMode() = enabled; }
} // namespace autograd
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
namespace oneflow {
namespace autograd {
struct GradMode {
static bool is_enabled();
static void set_enabled(bool enabled);
};
class AutoGradMode {
public:
AutoGradMode(bool enabled) : prev_mode_(GradMode::is_enabled()) {
GradMode::set_enabled(enabled);
}
~AutoGradMode() { GradMode::set_enabled(prev_mode_); }
bool prev_mode() const { return prev_mode_; }
private:
bool prev_mode_;
};
class NoGradGuard : public AutoGradMode {
public:
NoGradGuard() : AutoGradMode(false) {}
};
} // namespace autograd
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
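A minimal usage sketch (illustration only, not part of this commit): GradMode wraps the thread-local flag defined above, AutoGradMode sets it for one scope and restores the previous value on destruction, and NoGradGuard is simply AutoGradMode(false). The function name Example is hypothetical.

#include "oneflow/core/autograd/autograd_mode.h"

void Example() {
  using namespace oneflow::autograd;
  const bool outer = GradMode::is_enabled();  // current thread-local grad mode
  {
    NoGradGuard no_grad;  // forces grad mode off for this scope
    // GradMode::is_enabled() is false here, so no backward graph is recorded.
  }
  // The previous mode (outer) is restored when no_grad goes out of scope.
  (void)outer;
}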
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BaseActivationCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
};
class Silu : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
class Mish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
class Selu : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
class Softsign : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::SoftSignGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
class GeLU : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
class HardSigmoid : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
struct HardShrinkCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double lambd = 0.5;
};
class HardShrink : public OpExprGradFunction<HardShrinkCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
JUST(oneflow::VectorAt(*in_grads, 0)) =
JUST(functional::HardShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
class HardSwish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
// ===== Activation with params =====
struct ReLUCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class ReLU : public OpExprGradFunction<ReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
}
return Maybe<void>::Ok();
}
};
// ===== Activation with params =====
struct LeakyReluCaptureState : public AutoGradCaptureState {
bool requires_grad;
float alpha;
};
class LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct SoftplusCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double beta = 1.0;
double threshold = 20.0;
};
class Softplus : public OpExprGradFunction<SoftplusCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SoftplusCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->beta = JUST(composed_attrs.GetAttr<double>("beta"));
ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SoftplusCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::SoftplusGrad(
x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->beta, ctx->threshold));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct HardTanhCaptureState : public AutoGradCaptureState {
bool requires_grad;
double min_val;
double max_val;
};
class HardTanh : public OpExprGradFunction<HardTanhCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->min_val = JUST(composed_attrs.GetAttr<double>("min_val"));
ctx->max_val = JUST(composed_attrs.GetAttr<double>("max_val"));
ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct EluCaptureState : public AutoGradCaptureState {
bool requires_grad;
double alpha;
};
class Elu : public OpExprGradFunction<EluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct CeluCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double alpha = 1.0;
};
class Celu : public OpExprGradFunction<CeluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(CeluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CeluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::CeluGrad(x, out_grads.at(0), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct SoftShrinkCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double alpha = 0.5;
};
class SoftShrink : public OpExprGradFunction<SoftShrinkCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SoftShrinkCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SoftShrinkCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
JUST(oneflow::VectorAt(*in_grads, 0)) =
JUST(functional::SoftShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->alpha));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct PReLUCaptureState : public AutoGradCaptureState {
bool input_requires_grad;
bool alpha_requires_grad;
};
class PReLU : public OpExprGradFunction<PReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
ctx->alpha_requires_grad = inputs.at(1)->requires_grad(); // alpha
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(inputs.at(1));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& x = ctx->SavedTensors().at(0);
const auto& alpha = ctx->SavedTensors().at(1);
in_grads->resize(2);
if (ctx->input_requires_grad || ctx->alpha_requires_grad) {
const auto& grads = JUST(functional::PReluGrad(dy, x, alpha));
if (ctx->input_requires_grad) { in_grads->at(0) = grads->at(0); }
if (ctx->alpha_requires_grad) { in_grads->at(1) = grads->at(1); }
}
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
struct ThresholdCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double threshold = 0.0;
};
class Threshold : public OpExprGradFunction<ThresholdCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ThresholdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold_val"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ThresholdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
JUST(oneflow::VectorAt(*in_grads, 0)) =
JUST(functional::ThresholdGrad(x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->threshold));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("silu", Silu);
REGISTER_OP_EXPR_GRAD_FUNCTION("mish", Mish);
REGISTER_OP_EXPR_GRAD_FUNCTION("selu", Selu);
REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", Softsign);
REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("gelu", GeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardsigmoid", HardSigmoid);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink", HardShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardswish", HardSwish);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh);
REGISTER_OP_EXPR_GRAD_FUNCTION("elu", Elu);
REGISTER_OP_EXPR_GRAD_FUNCTION("celu", Celu);
REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
} // namespace one
} // namespace oneflow
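The REGISTER_OP_EXPR_GRAD_FUNCTION lines above bind each forward op type name to its backward class. As a hedged sketch of how a further element-wise activation could plug into the same BaseActivation pattern: the op name "foo" and functional::FooGrad below are placeholders for illustration only, not APIs provided by this commit.

class Foo : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      // BaseActivation::Capture saved the forward input x at index 0.
      const auto& x = ctx->SavedTensors().at(0);
      // functional::FooGrad stands in for the matching backward kernel.
      in_grads->at(0) = JUST(functional::FooGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("foo", Foo);  // placeholder op type name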
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct AdaptivePoolCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {
public:
using OpExprGradFunction<AdaptivePoolCaptureState>::Init;
Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);
Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::string mode_;
int32_t ndims_;
};
Maybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const int& ndims) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
mode_ = mode;
ndims_ = ndims;
return Maybe<void>::Ok();
}
Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::AdaptivePoolNdGrad(x, out_grads.at(0), mode_, ndims_));
return Maybe<void>::Ok();
}
class AdaptiveAvgPool1dGrad final : public AdaptivePoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 1); }
};
class AdaptiveAvgPool2dGrad final : public AdaptivePoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 2); }
};
class AdaptiveAvgPool3dGrad final : public AdaptivePoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 3); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d", AdaptiveAvgPool1dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d", AdaptiveAvgPool2dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d", AdaptiveAvgPool3dGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct AddNCaptureState : public AutoGradCaptureState {
int32_t input_num;
std::vector<bool> requires_grad;
};
class AddN : public OpExprGradFunction<AddNCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->input_num = inputs.size();
ctx->requires_grad.resize(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ctx->requires_grad[i] = inputs.at(i)->requires_grad();
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(ctx->input_num);
for (int i = 0; i < ctx->input_num; ++i) {
if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); }
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("add_n", AddN);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct AffineGridInterpState : public AutoGradCaptureState {
Shape size;
bool align_corners = false;
bool requires_grad = false;
};
class AffineGrid : public OpExprGradFunction<AffineGridInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(AffineGridInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad(); // theta
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->size = JUST(composed_attrs.GetAttr<Shape>("size"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
in_grads->at(0) =
JUST(functional::AffineGridGrad(out_grads.at(0), ctx->size, ctx->align_corners));
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("affine_grid", AffineGrid);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct AsStridedCaptureState : public AutoGradCaptureState {
std::vector<int32_t> size;
std::vector<int32_t> stride;
int32_t storage_offset = 0;
bool requires_grad = false;
};
class AsStrided : public OpExprGradFunction<AsStridedCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> AsStrided::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
ctx->storage_offset = JUST(composed_attrs.GetAttr<int32_t>("storage_offset"));
return Maybe<void>::Ok();
}
Maybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& input = ctx->SavedTensors().at(0);
std::vector<int32_t> size = ctx->size;
std::vector<int32_t> stride = ctx->stride;
int32_t storage_offset = ctx->storage_offset;
in_grads->resize(1);
in_grads->at(0) =
JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("as_strided", AsStrided);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
namespace {
struct AvgPoolCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
size_t input_index = 0;
std::string data_format;
std::vector<int32_t> padding;
std::vector<int32_t> kernel_size;
std::vector<int32_t> stride;
bool ceil_mode = false;
bool count_include_pad = false;
int32_t divisor_override = 0;
};
class AvgPoolNdGrad : public OpExprGradFunction<AvgPoolCaptureState> {
public:
virtual ~AvgPoolNdGrad() = default;
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> AvgPoolNdGrad::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> AvgPoolNdGrad::Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>("count_include_pad"));
ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>("divisor_override"));
return Maybe<void>::Ok();
}
Maybe<void> AvgPoolNdGrad::Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
int32_t ndims = ctx->kernel_size.size();
const auto& input = ctx->SavedTensors().at(ctx->input_index);
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::AvgPoolNdGrad(
input, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size, ctx->stride,
ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override));
return Maybe<void>::Ok();
}
} // namespace
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d", AvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d", AvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d", AvgPoolNdGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BatchGatherCaptureState : public AutoGradCaptureState {
int64_t num_segments;
bool requires_grad;
};
class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> BatchGather::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& in_shape = inputs.at(0)->shape();
const auto& indices_shape = inputs.at(1)->shape();
ctx->num_segments = in_shape->At(indices_shape->NumAxes() - 1);
ctx->SaveTensorForBackward(inputs.at(1));
return Maybe<void>::Ok();
}
Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& indices = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("batch_gather", BatchGather);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BiasAddCaptureState : public AutoGradCaptureState {
bool input_requires_grad;
bool bias_requires_grad;
int32_t axis;
};
class BiasAdd : public OpExprGradFunction<BiasAddCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(BiasAddCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->bias_requires_grad = inputs.at(1)->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int32_t>("axis"));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const BiasAddCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
in_grads->resize(2);
if (ctx->bias_requires_grad) {
std::vector<int32_t> reduce_axes_vec;
reduce_axes_vec.reserve(num_axes);
for (int i = 0; i < num_axes; ++i) {
if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }
}
in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), reduce_axes_vec, false));
}
if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("bias_add", BiasAdd);
} // namespace one
} // namespace oneflow
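A small illustration (not part of this commit) of the axis selection in BiasAdd::Apply above: the bias gradient sums dy over every axis except the bias axis, e.g. a 4-D dy with axis = 1 is reduced over {0, 2, 3}. BiasReduceAxes is a hypothetical standalone helper that mirrors the loop in the code.

#include <cstdint>
#include <vector>

std::vector<int32_t> BiasReduceAxes(int64_t num_axes, int32_t bias_axis) {
  std::vector<int32_t> reduce_axes;
  reduce_axes.reserve(num_axes);
  for (int32_t i = 0; i < num_axes; ++i) {
    if (i != bias_axis) { reduce_axes.emplace_back(i); }  // skip only the bias axis
  }
  return reduce_axes;  // num_axes = 4, bias_axis = 1  ->  {0, 2, 3}
}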
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
in_grads->resize(ctx->SavedTensors().size());
if (ctx->SavedTensors().size() == 3) {
const auto& weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
} else {
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy", BinaryCrossEntropy);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_pos_weight = false;
};
class BinaryCrossEntropyWithLogits
: public OpExprGradFunction<BinaryCrossEntropyWithLogitsCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropyWithLogits::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight or pos_weight
}
if (inputs.size() == 4) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight
ctx->SaveTensorForBackward(inputs.at(3)); // pos_weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
in_grads->resize(ctx->SavedTensors().size());
if (ctx->SavedTensors().size() == 3) {
if (ctx->has_pos_weight) {
const auto& pos_weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
} else {
const auto& weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
}
} else if (ctx->SavedTensors().size() == 4) {
const auto& weight = ctx->SavedTensors().at(2);
const auto& pos_weight = ctx->SavedTensors().at(3);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
} else {
in_grads->at(0) =
JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_pos_weight = false;
};
class BinaryCrossEntropyWithLogitsReduceMean
: public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads size should be equal to 1. ";
const auto& dy = JUST(VectorAt(out_grads, 0));
const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
in_grads->resize(ctx->SavedTensors().size());
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean",
BinaryCrossEntropyWithLogitsReduceMean);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BroadcastBinaryCaptureState : public AutoGradCaptureState {
int x_index = -1;
int y_index = -1;
int z_index = -1;
bool x_requires_grad = false;
bool y_requires_grad = false;
bool broadcast_x = false;
bool broadcast_y = false;
};
class BroadcastBinaryGrad : public OpExprGradFunction<BroadcastBinaryCaptureState> {
public:
BroadcastBinaryGrad() = default;
virtual ~BroadcastBinaryGrad() = default;
virtual Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->y_requires_grad = inputs.at(1)->requires_grad();
ctx->broadcast_x = (*inputs.at(0)->shape() != *outputs.at(0)->shape());
ctx->broadcast_y = (*inputs.at(1)->shape() != *outputs.at(0)->shape());
return SaveTensorForBackward(ctx, inputs, outputs);
}
protected:
virtual Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs) const = 0;
};
class BroadcastAdd : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
} else {
in_grads->at(0) = out_grads.at(0);
}
}
if (ctx->y_requires_grad) {
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y));
} else {
in_grads->at(1) = out_grads.at(0);
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (ctx->y_requires_grad && ctx->broadcast_y) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);
class BroadcastSub : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
} else {
in_grads->at(0) = out_grads.at(0);
}
}
if (ctx->y_requires_grad) {
const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), Scalar(-1.f), false));
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y));
} else {
in_grads->at(1) = grad;
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (ctx->y_requires_grad && ctx->broadcast_y) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);
class BroadcastMul : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y));
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
} else {
in_grads->at(0) = x_grad;
}
}
if (ctx->y_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x));
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
} else {
in_grads->at(1) = y_grad;
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
}
if (ctx->y_requires_grad) {
if (ctx->x_index == -1 /*x has not been saved*/) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
}
if (ctx->broadcast_y && ctx->y_index == -1 /*y has not been saved*/) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);
class BroadcastDiv : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& x_grad = JUST(functional::Div(out_grads.at(0), y));
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
} else {
in_grads->at(0) = x_grad;
}
}
if (ctx->y_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y));
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
}
if (ctx->y_requires_grad) {
if (ctx->y_index == -1 /*y has not been saved*/) {
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);
class BroadcastPow : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->resize(2);
if (ctx->x_requires_grad) {
in_grads->at(0) = JUST(functional::BroadcastPowXGrad(out_grads.at(0), x, y, z));
}
if (ctx->y_requires_grad) {
in_grads->at(1) = JUST(functional::BroadcastPowYGrad(out_grads.at(0), x, y, z));
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_pow", BroadcastPow);
class BroadcastMinMax : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& out_shape = *(out_grads.at(0)->shape());
in_grads->resize(2);
if (ctx->x_requires_grad || ctx->y_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
auto broad_x_ = x;
auto broad_y_ = y;
if (ctx->broadcast_x) {
const auto& x_shape = *(x->shape());
const Shape& left_extended_x_shape =
CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
if (left_extended_x_shape == out_shape) {
broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> x_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
}
}
if (ctx->broadcast_y) {
const auto& y_shape = *(y->shape());
const Shape& left_extended_y_shape =
CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
if (left_extended_y_shape == out_shape) {
broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> y_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
}
}
const auto& broad_grads =
JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x));
} else {
in_grads->at(0) = broad_grads->at(0);
}
}
if (ctx->y_requires_grad) {
if (ctx->broadcast_y) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y));
} else {
in_grads->at(1) = broad_grads->at(1);
}
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad || ctx->y_requires_grad) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)>
elementwise_grad_functor_;
};
class BroadcastMinimum : public BroadcastMinMax {
public:
Maybe<void> Init(const OpExpr& op) override {
JUST(BroadcastMinMax::Init(op));
elementwise_grad_functor_ = functional::ElementwiseMinGrad;
return Maybe<void>::Ok();
}
};
class BroadcastMaximum : public BroadcastMinMax {
public:
Maybe<void> Init(const OpExpr& op) override {
JUST(BroadcastMinMax::Init(op));
elementwise_grad_functor_ = functional::ElementwiseMaxGrad;
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct BroadcastFModCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
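// Gradient note: for z = fmod(x, y), dz/dx = 1 almost everywhere, so the gradient
// of the dividend is dL/dz passed straight through; no gradient is produced for y.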
class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct BroadCastLikeCaptureState : public AutoGradCaptureState {
bool requires_grad;
size_t input_index;
std::vector<int32_t> broadcast_axes;
};
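// Gradient note: the backward of broadcast_like reduces dL/dz back to the shape of
// the original input by summing over the recorded broadcast_axes (ReduceSumLike).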
class BroadCastLike : public OpExprGradFunction<BroadCastLikeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BroadCastLike::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BroadCastLike::Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->broadcast_axes = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("broadcast_axes"));
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> BroadCastLike::Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& x = ctx->SavedTensors().at(ctx->input_index);
in_grads->resize(2);
in_grads->at(0) = JUST(functional::ReduceSumLike(out_grads.at(0), x, ctx->broadcast_axes));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_like", BroadCastLike);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {
namespace one {
struct CastCaptureState : public AutoGradCaptureState {
Symbol<DType> dtype;
};
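// Gradient note: cast is an elementwise identity up to dtype conversion, so the
// backward simply casts the incoming gradient back to the input's original dtype.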
class Cast : public OpExprGradFunction<CastCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->dtype = inputs.at(0)->dtype();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::Cast(out_grads[0], ctx->dtype, /*pin_memory=*/false));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast", Cast);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ClipByScalarCaptureState : public AutoGradCaptureState {
bool requires_grad;
Scalar min;
Scalar max;
};
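// Gradient note: clamp passes the gradient through where the input lies inside
// [min, max] and zeroes it elsewhere; ClampGrad receives the saved input together
// with min/max read from the floating or integral attributes matching the dtype.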
class ClipByScalar : public OpExprGradFunction<ClipByScalarCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ClipByScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>("floating_min")));
ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
} else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_min")));
ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
} else {
UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ClipByScalarCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, ctx->max));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar", ClipByScalar);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ClipByScalarMaxCaptureState : public AutoGradCaptureState {
bool requires_grad;
Scalar max;
};
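// Gradient note: one-sided upper clamp; same pass-through rule as clip_by_scalar,
// but with no lower bound, so NullOpt is forwarded as min to ClampGrad.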
class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ClipByScalarMaxCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
} else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
} else {
UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ClipByScalarMaxCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, /*min=*/NullOpt, ctx->max));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_max", ClipByScalarMax);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ClipByScalarMinCaptureState : public AutoGradCaptureState {
bool requires_grad;
Scalar min;
};
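// Gradient note: one-sided lower clamp, the mirror of clip_by_scalar_max; NullOpt
// is forwarded as max to ClampGrad.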
class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ClipByScalarMinCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>("floating_min")));
} else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_min")));
} else {
UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ClipByScalarMinCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min,
/*max=*/NullOpt));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_min", ClipByScalarMin);
} // namespace one
} // namespace oneflow