Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
namespace one {
struct GlobalToGlobalState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
};
class GlobalToGlobalGradFunction : public OpExprGradFunction<GlobalToGlobalState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const GlobalToGlobalOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();
return Maybe<void>::Ok();
}
Maybe<void> Capture(GlobalToGlobalState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());
ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());
return Maybe<void>::Ok();
}
Maybe<void> Apply(const GlobalToGlobalState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
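    // The backward of global_to_global is another global_to_global: the output grad is
    // transformed back onto the input's parallel_desc, using grad_nd_sbp_ when it was
    // provided and the output grad's own nd_sbp otherwise.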
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_global())
<< Error::RuntimeError()
<< "Expected global tensor for global_to_global but got local tensor";
in_grads->resize(1);
const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
if (LazyMode::is_enabled()) {
(*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,
{}, /* check_meta */ false, /*copy=*/false));
} else {
const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
(*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,
*grad_grad_sbp_list, /* check_meta */ false,
/*copy=*/false));
}
return Maybe<void>::Ok();
}
private:
Optional<Symbol<NdSbp>> grad_nd_sbp_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("global_to_global", GlobalToGlobalGradFunction);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
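// Gradient functions for the gradient-accumulation ops. They come in symmetric pairs:
// the backward of "repeat" is a collect ("acc") and vice versa, and the backward of
// "pack" is an "unpack" and vice versa.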
struct GradAccRepeatCaptureState : public AutoGradCaptureState {
int32_t repeat_num = 1;
};
class GradAccRepeat : public OpExprGradFunction<GradAccRepeatCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> GradAccRepeat::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> GradAccRepeat::Capture(GradAccRepeatCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->repeat_num = JUST(composed_attrs.GetAttr<int32_t>("repeat_num"));
return Maybe<void>::Ok();
}
Maybe<void> GradAccRepeat::Apply(const GradAccRepeatCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::GradAccCollect(out_grads[0], ctx->repeat_num));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("repeat", GradAccRepeat);
struct GradAccCollectCaptureState : public AutoGradCaptureState {
int32_t max_acc_num = 1;
};
class GradAccCollect : public OpExprGradFunction<GradAccCollectCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GradAccCollectCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> GradAccCollect::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> GradAccCollect::Capture(GradAccCollectCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->max_acc_num = JUST(composed_attrs.GetAttr<int32_t>("max_acc_num"));
return Maybe<void>::Ok();
}
Maybe<void> GradAccCollect::Apply(const GradAccCollectCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::GradAccRepeat(out_grads[0], ctx->max_acc_num));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("acc", GradAccCollect);
struct GradAccPackCaptureState : public AutoGradCaptureState {
int32_t pack_num = 1;
};
class GradAccPack : public OpExprGradFunction<GradAccPackCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> GradAccPack::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> GradAccPack::Capture(GradAccPackCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->pack_num = JUST(composed_attrs.GetAttr<int32_t>("pack_num"));
return Maybe<void>::Ok();
}
Maybe<void> GradAccPack::Apply(const GradAccPackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::GradAccUnpack(out_grads[0], ctx->pack_num));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("pack", GradAccPack);
struct GradAccUnpackCaptureState : public AutoGradCaptureState {
int32_t unpack_num = 1;
};
class GradAccUnpack : public OpExprGradFunction<GradAccUnpackCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> GradAccUnpack::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> GradAccUnpack::Capture(GradAccUnpackCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->unpack_num = JUST(composed_attrs.GetAttr<int32_t>("unpack_num"));
return Maybe<void>::Ok();
}
Maybe<void> GradAccUnpack::Apply(const GradAccUnpackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::GradAccPack(out_grads[0], ctx->unpack_num));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("unpack", GradAccUnpack);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace one {
struct GraphFeedAndFetchCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class GraphFeedAndFetch : public OpExprGradFunction<GraphFeedAndFetchCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(GraphFeedAndFetchCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const GraphFeedAndFetchCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
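    // Feed/fetch is an identity for autograd: the output grad is passed through unchanged
    // when the input required grad.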
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("graph_feed_and_fetch", GraphFeedAndFetch);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct GroupNormCaptureState : public AutoGradCaptureState {
double epsilon = 1e-5;
bool x_requires_grad = true;
bool affine = true;
int32_t num_groups = 1;
size_t x_index = 0;
size_t mean_index = 1;
size_t inv_variance_index = 2;
size_t gamma_index = 3;
std::string data_format;
std::string activation;
};
class GroupNorm : public OpExprGradFunction<GroupNormCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::string op_name_;
};
Maybe<void> GroupNorm::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
op_name_ = fw_op_expr->op_name();
return Maybe<void>::Ok();
}
Maybe<void> GroupNorm::Capture(GroupNormCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->affine = JUST(composed_attrs.GetAttr<bool>("affine"));
ctx->epsilon = JUST(composed_attrs.GetAttr<double>("epsilon"));
ctx->num_groups = JUST(composed_attrs.GetAttr<int32_t>("num_groups"));
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->activation = JUST(composed_attrs.GetAttr<std::string>("activation"));
if (ctx->affine) {
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
} else {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
}
CHECK_EQ_OR_RETURN(outputs.size(), 3); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
if (ctx->x_requires_grad || ctx->affine) {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->mean_index = ctx->SaveTensorForBackward(outputs.at(1));
ctx->inv_variance_index = ctx->SaveTensorForBackward(outputs.at(2));
if (ctx->affine) {
ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1)); // save gamma.
}
}
return Maybe<void>::Ok();
}
Maybe<void> GroupNorm::Apply(const GroupNormCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
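  // Backward is only implemented for the "channels_first" layout with no fused activation.
  // Gamma/beta grads come from GroupNormParamGrad; the input grad comes from GroupNormGrad.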
CHECK_EQ_OR_RETURN(ctx->data_format, "channels_first");
CHECK_EQ_OR_RETURN(ctx->activation, "none");
const auto& saved_tensors = ctx->SavedTensors();
if (ctx->affine) {
in_grads->resize(3);
} else {
in_grads->resize(1);
}
const auto& dy = out_grads.at(0);
const auto& x = saved_tensors.at(ctx->x_index);
const auto& mean = saved_tensors.at(ctx->mean_index);
const auto& inv_variance = saved_tensors.at(ctx->inv_variance_index);
if (ctx->affine) {
const auto& results = JUST(functional::GroupNormParamGrad(dy, x, mean, inv_variance));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->affine) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, gamma,
ctx->num_groups, ctx->epsilon));
} else {
in_grads->at(0) = JUST(functional::GroupNormGrad(dy, x, mean, inv_variance, NullOpt,
ctx->num_groups, ctx->epsilon));
}
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("group_norm", GroupNorm);
} // namespace one
} // namespace oneflow
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/job/lazy_mode.h"
 namespace oneflow {
 namespace one {
@@ -37,7 +38,15 @@ class Identity : public OpExprGradFunction<IdentityCaptureState> {
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
     in_grads->resize(1);
-    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
+    if (ctx->requires_grad) {
+      if (LazyMode::is_enabled()) {
+        // requires an intermediate node to avoid redundant memory copy or
+        // communication in lazy mode
+        in_grads->at(0) = JUST(functional::Identity(out_grads.at(0)));
+      } else {
+        in_grads->at(0) = out_grads.at(0);
+      }
+    }
     return Maybe<void>::Ok();
   }
 };
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct InvCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class Inv : public OpExprGradFunction<InvCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(InvCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
if (ctx->requires_grad) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const InvCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
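    // For y = x^{-1}, the input grad is dx = -y^T * dy * y^T, computed from the saved
    // forward output y.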
if (ctx->requires_grad) {
const auto& output = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& dy = JUST(VectorAt(out_grads, 0));
JUST(VectorAt(*in_grads, 0)) = JUST(functional::Negative(JUST(functional::MatMul(
output, JUST(functional::MatMul(dy, output, false, true, 1.0)), true, false, 1.0))));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("inv", Inv);
} // namespace one
} // namespace oneflow
@@ -20,7 +20,8 @@ namespace oneflow {
 namespace one {
 struct KLDivLossCaptureState : public AutoGradCaptureState {
-  bool requires_grad = false;
+  bool input_requires_grad = false;
+  bool target_requires_grad = false;
   bool log_target = false;
 };
@@ -44,25 +45,31 @@ Maybe<void> KLDivLoss::Init(const OpExpr& op) {
 }
 Maybe<void> KLDivLoss::Capture(KLDivLossCaptureState* ctx, const TensorTuple& inputs,
                                const TensorTuple& outputs, const AttrMap& attrs) const {
-  ctx->requires_grad = inputs.at(0)->requires_grad();
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
+  ctx->input_requires_grad = inputs[0]->requires_grad();
+  ctx->target_requires_grad = inputs[1]->requires_grad();
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->log_target = JUST(composed_attrs.GetAttr<bool>("log_target"));
-  ctx->SaveTensorForBackward(inputs.at(0));  // input
-  ctx->SaveTensorForBackward(inputs.at(1));  // target
+  ctx->SaveTensorForBackward(inputs[0]);  // input
+  ctx->SaveTensorForBackward(inputs[1]);  // target
   return Maybe<void>::Ok();
 }
 Maybe<void> KLDivLoss::Apply(const KLDivLossCaptureState* ctx, const TensorTuple& out_grads,
                              TensorTuple* in_grads) const {
-  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-  const auto& dy = out_grads.at(0);
-  const auto& input = ctx->SavedTensors().at(0);
-  const auto& target = ctx->SavedTensors().at(1);
-  in_grads->resize(ctx->SavedTensors().size());
-  in_grads->at(0) = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target));
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1);            // NOLINT(maybe-need-error-msg)
+  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2);  // NOLINT(maybe-need-error-msg)
+  const auto& dy = out_grads[0];
+  const auto& input = ctx->SavedTensors()[0];
+  const auto& target = ctx->SavedTensors()[1];
+  in_grads->resize(2);
+  if (ctx->input_requires_grad) {
+    (*in_grads)[0] = JUST(functional::KLDivLossGrad(dy, input, target, ctx->log_target));
+  }
+  if (ctx->target_requires_grad) {
+    (*in_grads)[1] = JUST(functional::KLDivLossTargetGrad(dy, input, target, ctx->log_target));
+  }
   return Maybe<void>::Ok();
 }
...
@@ -108,10 +108,10 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);
   if (ctx->has_affine) {
-    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64
-    // begin_params_axis, Double epsilon).
-    const auto& results = JUST(
-        functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis, ctx->epsilon));
+    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
+    // Int64 begin_params_axis)
+    const auto& results =
+        JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
     in_grads->at(1) = results->at(0);  // For gamma.
     in_grads->at(2) = results->at(1);  // For beta.
   }
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/just.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional_api.yaml.h"
namespace oneflow {
namespace one {
struct LinalgCrossCaptureState : public AutoGradCaptureState {
int64_t dim = -1;
bool input_requires_grad = false;
bool other_requires_grad = false;
};
class LinalgCross : public OpExprGradFunction<LinalgCrossCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> LinalgCross::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> LinalgCross::Capture(LinalgCrossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->other_requires_grad = inputs.at(1)->requires_grad();
if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
if (ctx->other_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
return Maybe<void>::Ok();
}
Maybe<void> LinalgCross::Apply(const LinalgCrossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
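  // For c = a x b: grad_a = b x dy and grad_b = dy x a, each computed with another
  // linalg_cross over the tensor saved for that branch.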
in_grads->resize(ctx->SavedTensors().size());
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
if (ctx->input_requires_grad) {
in_grads->at(0) =
JUST(functional::LinalgCross(ctx->SavedTensors().at(0), out_grads.at(0), ctx->dim));
}
if (ctx->other_requires_grad) {
in_grads->at(1) = JUST(functional::LinalgCross(
out_grads.at(0), ctx->SavedTensors().at(ctx->input_requires_grad ? 1 : 0), ctx->dim));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("linalg_cross", LinalgCross);
} // namespace one
} // namespace oneflow
\ No newline at end of file
@@ -38,39 +38,23 @@ class LogSoftmax : public OpExprGradFunction<LogSoftmaxCaptureState> {
   std::shared_ptr<OpExpr> grad_op_;
 };
-Maybe<void> LogSoftmax::Init(const OpExpr& op) {
-  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
-  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
-  const std::string& op_name = fw_op_expr->op_name();
-  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
-  grad_op_ = JUST(one::OpBuilder("log_softmax_grad", GradientOpName(op_name))
-                      .Input("prob")
-                      .Input("dy")
-                      .Output("dx")
-                      .Build());
-  return Maybe<void>::Ok();
-}
+Maybe<void> LogSoftmax::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
 Maybe<void> LogSoftmax::Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs,
                                 const TensorTuple& outputs, const AttrMap& attrs) const {
-  ComposedAttrMap composed_attrs(attrs, base_attrs_);
   CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
   ctx->requires_grad = inputs.at(0)->requires_grad();
-  if (!ctx->requires_grad) return Maybe<void>::Ok();
   ctx->SaveTensorForBackward(outputs.at(0));
   return Maybe<void>::Ok();
 }
 Maybe<void> LogSoftmax::Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads,
                               TensorTuple* in_grads) const {
-  if (!ctx->requires_grad) return Maybe<void>::Ok();
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
   const auto& dy = out_grads.at(0);
-  const auto& prob = ctx->SavedTensors().at(0);
+  const auto& y = ctx->SavedTensors().at(0);
   in_grads->resize(1);
-  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {prob, dy}));
+  in_grads->at(0) = JUST(functional::LogSoftmaxGrad(dy, y));
   return Maybe<void>::Ok();
 }
...
@@ -30,7 +30,7 @@ typedef Maybe<one::Tensor> (*UnaryBwFunc)(const std::shared_ptr<one::Tensor>&,
                                           const std::shared_ptr<one::Tensor>&);
 template<UnaryBwFunc BwFunc>
-class UnaryMathOp : public OpExprGradFunction<UnaryMathCaptureState> {
+class UnaryMathBwdWithDyXOp : public OpExprGradFunction<UnaryMathCaptureState> {
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
   Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,
@@ -52,20 +52,96 @@ class UnaryMathOp : public OpExprGradFunction<UnaryMathCaptureState> {
   std::shared_ptr<OpExpr> grad_op_;
 };
-#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS(op_type_name, op_cls) \
-  class op_cls##Cls final : public UnaryMathOp<functional::op_cls##Grad> {}; \
-  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);
-
-OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS, MATH_UNARY_ELEMENTWISE_FUNC_SEQ);
-OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS,
-                     OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh));
-// higher order derivative
-OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS,
-                     OF_PP_MAKE_TUPLE_SEQ("sin_grad", SinGrad));
-OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS,
-                     OF_PP_MAKE_TUPLE_SEQ("cos_grad", CosGrad));
-#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_CLASS
+template<UnaryBwFunc BwFunc>
+class UnaryMathBwdWithDyYOp : public OpExprGradFunction<UnaryMathCaptureState> {
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    ctx->x_requires_grad = inputs.at(0)->requires_grad();
+    ctx->SaveTensorForBackward(outputs.at(0));
+    return Maybe<void>::Ok();
+  }
+  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }
+    const auto& y = ctx->SavedTensors().at(0);
+    in_grads->at(0) = JUST(BwFunc(y, out_grads.at(0)));
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+class UnaryMathBwdWithFillZeroOp : public OpExprGradFunction<UnaryMathCaptureState> {
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    ctx->x_requires_grad = inputs.at(0)->requires_grad();
+    return Maybe<void>::Ok();
+  }
+  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }
+    in_grads->at(0) = JUST(functional::ZerosLike(out_grads[0]));
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS(op_type_name, op_cls)     \
+  class op_cls##Cls final : public UnaryMathBwdWithDyXOp<functional::op_cls##Grad> {}; \
+  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);
+
+OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS,
+                     MATH_UNARY_ELEMENTWISE_PRIMITIVE_FUNC_BWD_WITH_DY_X_SEQ);
+OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS,
+                     OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh));
+#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_X_CLASS
+
+#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS(op_type_name, op_cls)     \
+  class op_cls##Cls final : public UnaryMathBwdWithDyYOp<functional::op_cls##Grad> {}; \
+  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##Cls);
+
+OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS,
+                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_DY_Y_SEQ);
+#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_DY_Y_CLASS
+
+#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS(op_type_name, op_cls)     \
+  class op_cls##Cls final : public UnaryMathBwdWithDyYOp<functional::op_cls##Grad> {}; \
+  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, UnaryMathBwdWithFillZeroOp);
+
+OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS,
+                     MATH_UNARY_ELEMENTWISE_FUNC_BWD_WITH_FILL_SEQ);
+#undef INSTANTIAT_AND_REGISTER_UNARY_MATHOP_WITH_FILL_CLASS
+
+class NegativeOp : public OpExprGradFunction<UnaryMathCaptureState> {
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+  Maybe<void> Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    ctx->x_requires_grad = inputs.at(0)->requires_grad();
+    return Maybe<void>::Ok();
+  }
+  Maybe<void> Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }
+    in_grads->at(0) = JUST(functional::Negative(out_grads[0]));
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::shared_ptr<OpExpr> grad_op_;
+};
+REGISTER_OP_EXPR_GRAD_FUNCTION("negative", NegativeOp);
 }  // namespace one
 }  // namespace oneflow
@@ -18,6 +18,7 @@ limitations under the License.
 #include "oneflow/core/framework/op_expr.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
 #include "oneflow/core/functional/functional.h"
+#include "oneflow/core/common/container_util.h"
 namespace oneflow {
 namespace one {
@@ -102,40 +103,194 @@ Maybe<void> Matmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_
   return Maybe<void>::Ok();
 }
-class BroadcastMatmul : public Matmul {
- public:
-  Maybe<void> Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,
-                    TensorTuple* in_grads) const override;
-};
-
-Maybe<void> BroadcastMatmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,
-                                   TensorTuple* in_grads) const {
-  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
-  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-  in_grads->resize(2);
-  if (ctx->requires_grad_a) {
-    const auto& input_b = ctx->SavedTensors().at(ctx->b_index);
-    if (ctx->transpose_a) {
-      in_grads->at(0) =
-          JUST(functional::MatMul(input_b, out_grads.at(0), ctx->transpose_b, true, ctx->alpha));
-    } else {
-      in_grads->at(0) = JUST(
-          functional::MatMul(out_grads.at(0), input_b, false, !(ctx->transpose_b), ctx->alpha));
-    }
-  }
-  if (ctx->requires_grad_b) {
-    const auto& input_a = ctx->SavedTensors().at(ctx->a_index);
-    if (ctx->transpose_b) {
-      in_grads->at(1) =
-          JUST(functional::BroadcastMatmulGradB(out_grads.at(0), input_a, ctx->alpha));
-    } else {
-      in_grads->at(1) =
-          JUST(functional::BroadcastMatmulGradB(input_a, out_grads.at(0), ctx->alpha));
-    }
-  }
-  return Maybe<void>::Ok();
-}
+struct BroadcastMatmulCaptureState : public AutoGradCaptureState {
+  bool transpose_a = false;
+  bool transpose_b = false;
+  double alpha = 1.0;
+  bool requires_grad_a = true;
+  bool requires_grad_b = true;
+  size_t a_index = 0;
+  size_t b_index = 1;
+  bool broadcast_a = false;
+  bool broadcast_b = false;
+  int64_t b_num_axes = 0;
+};
+
+class BroadcastMatmul : public OpExprGradFunction<BroadcastMatmulCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const BroadcastMatmulCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ protected:
+  AttrMap base_attrs_;
+};
+
+Maybe<void> BroadcastMatmul::Init(const OpExpr& op) {
+  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> BroadcastMatmul::Capture(BroadcastMatmulCaptureState* ctx, const TensorTuple& inputs,
+                                     const TensorTuple& outputs, const AttrMap& attrs) const {
+  ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();
+  ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();
+  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
+  const auto a_shape = JUST(VectorAt(inputs, 0))->shape();
+  const auto b_shape = JUST(VectorAt(inputs, 1))->shape();
+  const int64_t a_num_axes = a_shape->NumAxes();
+  const int64_t b_num_axes = b_shape->NumAxes();
+  const size_t num_max_batch_dims = std::max(a_num_axes, b_num_axes) - 2;
+  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {
+    const int64_t num_batch_dims = num_dims - 2;
+    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;
+    return [num_padding_dims, shape_dim](size_t index) {
+      return index < num_padding_dims ? 1 : shape_dim.At(index - num_padding_dims);
+    };
+  };
+  auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape);
+  auto GetBBatchDim = MakeGetBatchDim(b_num_axes, *b_shape);
+  bool broadcast_a = false;
+  bool broadcast_b = false;
+  for (int32_t i = 0; i < num_max_batch_dims; i++) {
+    if (GetABatchDim(i) < GetBBatchDim(i) || a_num_axes < b_num_axes) {
+      broadcast_a = true;
+      break;
+    }
+  }
+  for (int32_t i = 0; i < num_max_batch_dims; i++) {
+    if (GetBBatchDim(i) < GetABatchDim(i) || b_num_axes < a_num_axes) {
+      broadcast_b = true;
+      break;
+    }
+  }
+  if (b_num_axes == 2 && !ctx->transpose_a) {
+    // In this case, we can directly use `broadcast_matmul_grad_b` OP to generate Grad instead of
+    // broadcast_matmul+reduce_sum_like.
+    broadcast_b = false;
+  }
+  ctx->broadcast_a = broadcast_a;
+  ctx->broadcast_b = broadcast_b;
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->transpose_a = JUST(composed_attrs.GetAttr<bool>("transpose_a"));
+  ctx->transpose_b = JUST(composed_attrs.GetAttr<bool>("transpose_b"));
+  ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
+  if (ctx->requires_grad_a) {
+    ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b
+    if (broadcast_a) {
+      ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a
+    }
+  }
+  if (ctx->requires_grad_b) {
+    ctx->b_num_axes = JUST(VectorAt(inputs, 1))->shape()->NumAxes();
+    ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a
+    if (broadcast_b) {
+      ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> BroadcastMatmul::Apply(const BroadcastMatmulCaptureState* ctx,
+                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {
+  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. ";
+  in_grads->resize(2);
+  const auto out_shape = JUST(VectorAt(out_grads, 0))->shape();
+  const int64_t out_num_axes = out_shape->NumAxes();
+  const size_t num_max_batch_dims = out_num_axes - 2;
+  auto MakeGetBatchDim = [num_max_batch_dims](size_t num_dims, const Shape& shape_dim) {
+    const int64_t num_batch_dims = num_dims - 2;
+    const int64_t num_padding_dims = num_max_batch_dims - num_batch_dims;
+    return [num_padding_dims, shape_dim](size_t index) {
+      return index < num_padding_dims ? 1 : shape_dim.At(index - num_padding_dims);
+    };
+  };
+  auto GetOutBatchDim = MakeGetBatchDim(out_num_axes, *out_shape);
+  if (ctx->requires_grad_a) {
+    std::shared_ptr<Tensor> broadcast_grad_a;
+    const auto& input_b = ctx->SavedTensors().at(ctx->b_index);
+    if (ctx->transpose_a) {
+      broadcast_grad_a = JUST(functional::MatMul(input_b, JUST(VectorAt(out_grads, 0)),
+                                                 ctx->transpose_b, true, ctx->alpha));
+    } else {
+      broadcast_grad_a = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_b, false,
                                                  !(ctx->transpose_b), ctx->alpha));
+    }
+    if (ctx->broadcast_a) {
+      const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index));
+      const auto a_shape = input_a->shape();
+      const int64_t a_num_axes = a_shape->NumAxes();
+      std::vector<int32_t> a_reduce_vec;
+      auto GetABatchDim = MakeGetBatchDim(a_num_axes, *a_shape);
+      const int64_t a_out_num_dim_differ = out_num_axes - a_num_axes;
+      for (int32_t i = 0; i < out_num_axes - 2; i++) {
+        if (GetOutBatchDim(i) > GetABatchDim(i)
+            || (GetOutBatchDim(i) == 1 && i < a_out_num_dim_differ)) {
+          a_reduce_vec.push_back(i);
+        }
+      }
+      JUST(VectorAt(*in_grads, 0)) =
+          JUST(functional::ReduceSumLike(broadcast_grad_a, input_a, a_reduce_vec));
+    } else {
+      JUST(VectorAt(*in_grads, 0)) = broadcast_grad_a;
+    }
+  }
+  if (ctx->requires_grad_b) {
+    const auto& input_a = ctx->SavedTensors().at(ctx->a_index);
+    if (ctx->b_num_axes == 2 && !ctx->transpose_a) {
+      if (ctx->transpose_b) {
+        JUST(VectorAt(*in_grads, 1)) = JUST(
+            functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha));
+      } else {
+        JUST(VectorAt(*in_grads, 1)) = JUST(
+            functional::BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha));
+      }
+    } else {
+      std::shared_ptr<Tensor> broadcast_grad_b;
+      if (ctx->transpose_b) {
+        broadcast_grad_b = JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), input_a, true,
+                                                   ctx->transpose_a, ctx->alpha));
+      } else {
+        broadcast_grad_b = JUST(functional::MatMul(input_a, JUST(VectorAt(out_grads, 0)),
+                                                   !ctx->transpose_a, false, ctx->alpha));
+      }
+      if (ctx->broadcast_b) {
+        const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));
+        const auto b_shape = input_b->shape();
+        std::vector<int32_t> b_reduce_vec;
+        auto GetBBatchDim = MakeGetBatchDim(ctx->b_num_axes, *b_shape);
+        const int64_t b_out_num_dim_differ = out_num_axes - ctx->b_num_axes;
+        for (int32_t i = 0; i < out_num_axes - 2; i++) {
+          if (GetOutBatchDim(i) > GetBBatchDim(i)
+              || (GetOutBatchDim(i) == 1 && i < b_out_num_dim_differ)) {
+            b_reduce_vec.push_back(i);
+          }
+        }
+        JUST(VectorAt(*in_grads, 1)) =
+            JUST(functional::ReduceSumLike(broadcast_grad_b, input_b, b_reduce_vec));
+      } else {
+        JUST(VectorAt(*in_grads, 1)) = broadcast_grad_b;
+      }
+    }
+  }
+  return Maybe<void>::Ok();
+}
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct MatrixVectorProductCaptureState : public AutoGradCaptureState {
bool requires_grad_a = false;
bool requires_grad_b = false;
size_t a_index = 0;
size_t b_index = 1;
};
class MatrixVectorProduct : public OpExprGradFunction<MatrixVectorProductCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(MatrixVectorProductCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const MatrixVectorProductCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
protected:
AttrMap base_attrs_;
};
Maybe<void> MatrixVectorProduct::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> MatrixVectorProduct::Capture(MatrixVectorProductCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();
ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();
if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
if (ctx->requires_grad_a) {
ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // input b
}
if (ctx->requires_grad_b) {
ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input a
}
return Maybe<void>::Ok();
}
Maybe<void> MatrixVectorProduct::Apply(const MatrixVectorProductCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
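  // For y = A * v, the gradients are dA = dy (outer product) v^T and dv = A^T * dy,
  // delegated to the dedicated grad functors below.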
if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. ";
in_grads->resize(2);
if (ctx->requires_grad_a) {
const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::MatrixVectorProductGradA(JUST(VectorAt(out_grads, 0)), input_b));
}
if (ctx->requires_grad_b) {
const auto& input_a = JUST(VectorAt(ctx->SavedTensors(), ctx->a_index));
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatrixVectorProductGradB(JUST(VectorAt(out_grads, 0)), input_a));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("matrix_vector_product", MatrixVectorProduct);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
namespace {
struct MaxUnpoolCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
size_t input_index = 0;
size_t indices_index = 0;
};
using FuncType = decltype(functional::MaxUnpool1dGrad);
template<FuncType F>
class MaxUnpoolNdGrad : public OpExprGradFunction<MaxUnpoolCaptureState> {
public:
virtual ~MaxUnpoolNdGrad() = default;
using OpExprGradFunction<MaxUnpoolCaptureState>::Init;
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const MaxUnpoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
template<FuncType F>
Maybe<void> MaxUnpoolNdGrad<F>::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
template<FuncType F>
Maybe<void> MaxUnpoolNdGrad<F>::Capture(MaxUnpoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->indices_index = ctx->SaveTensorForBackward(inputs.at(1));
return Maybe<void>::Ok();
}
template<FuncType F>
Maybe<void> MaxUnpoolNdGrad<F>::Apply(const MaxUnpoolCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
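  // Only the data input receives a gradient; the indices input (in_grads[1]) stays empty.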
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_LE_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& input = ctx->SavedTensors().at(ctx->input_index);
const auto& indices = ctx->SavedTensors().at(ctx->indices_index);
in_grads->resize(2);
(*in_grads)[0] = JUST(F(input, indices, out_grads[0]));
return Maybe<void>::Ok();
}
} // namespace
REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_1d", MaxUnpoolNdGrad<functional::MaxUnpool1dGrad>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_2d", MaxUnpoolNdGrad<functional::MaxUnpool2dGrad>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_unpool_3d", MaxUnpoolNdGrad<functional::MaxUnpool3dGrad>);
} // namespace one
} // namespace oneflow
@@ -92,10 +92,10 @@ class MedianWithIndices : public OpExprGradFunction<MedianWithIndicesCaptureStat
     const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
     const auto& indices = JUST(functional::Unsqueeze(JUST(VectorAt(ctx->SavedTensors(), 1)), -1));
     const auto& dout = JUST(functional::Unsqueeze(JUST(VectorAt(out_grads, 0)), -1));
-    JUST(VectorAt(*in_grads, 0)) = JUST(
-        functional::DimScatter(JUST(functional::Constant(*(input->shape()), Scalar(0),
-                                                          *dout->dtype(), JUST(dout->device()))),
-                               -1, indices, dout));
+    JUST(VectorAt(*in_grads, 0)) = JUST(functional::DimScatterUpdate(
+        JUST(functional::Constant(*(input->shape()), Scalar(0), *dout->dtype(),
+                                  JUST(dout->device()))),
+        -1, indices, dout, /*inplace*/ false));
   }
   return Maybe<void>::Ok();
 }
...
@@ -71,8 +71,8 @@ class Narrow : public OpExprGradFunction<NarrowCaptureState> {
           functional::Empty(ctx->shape, dy->dtype(), JUST(dy->device()), /*pin_memory=*/false));
     } else {
       like = JUST(
-          functional::ConsistentEmpty(ctx->shape, dy->dtype(), JUST(dy->parallel_desc()),
-                                      *JUST(private_details::RawGetSbpList(JUST(dy->nd_sbp())))));
+          functional::GlobalEmpty(ctx->shape, dy->dtype(), JUST(dy->parallel_desc()),
+                                  *JUST(private_details::RawGetSbpList(JUST(dy->nd_sbp())))));
     }
     in_grads->resize(1);
     in_grads->at(0) = JUST(functional::NarrowGrad(dy, like, ctx->dim, ctx->start, ctx->length));
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct OneEmbeddingFusedLookupCaptureState : public AutoGradCaptureState {
bool requires_grad{};
std::string embedding_name{};
int64_t line_size{};
int64_t embedding_size{};
int shadow_index{};
int ids_index{};
int input_num{};
};
class OneEmbeddingFusedLookup : public OpExprGradFunction<OneEmbeddingFusedLookupCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_GE_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad(); // shadow
ctx->shadow_index = ctx->SaveTensorForBackward(inputs.at(0)); // shadow
ctx->ids_index = ctx->SaveTensorForBackward(inputs.at(1)); // id
ctx->embedding_name = JUST(attrs.GetAttr<std::string>("embedding_name"));
ctx->line_size = JUST(attrs.GetAttr<int64_t>("line_size"));
ctx->embedding_size = JUST(attrs.GetAttr<int64_t>("embedding_size"));
ctx->input_num = inputs.size();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const OneEmbeddingFusedLookupCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
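    // The embedding gradient is handled by OneEmbeddingFusedLookupGrad (its result is not
    // assigned here); the shadow input only receives a zeros-like grad for autograd bookkeeping.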
in_grads->resize(ctx->input_num);
const auto& saved_tensors = ctx->SavedTensors();
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
if (ctx->requires_grad) {
JUST(functional::OneEmbeddingFusedLookupGrad(
saved_tensors.at(ctx->ids_index), JUST(VectorAt(out_grads, 0)), ctx->embedding_name,
ctx->line_size, ctx->embedding_size));
(*in_grads)[0] = JUST(functional::ZerosLike(saved_tensors.at(ctx->shadow_index)));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("one_embedding_fused_lookup", OneEmbeddingFusedLookup);
} // namespace one
} // namespace oneflow
@@ -20,12 +20,12 @@ limitations under the License.
 namespace oneflow {
 namespace one {
-struct Pad2dCaptureState : public AutoGradCaptureState {
-  bool requires_grad;
-  std::vector<int64_t> paddings;
+struct PadNdCaptureState : public AutoGradCaptureState {
+  bool requires_grad = false;
+  std::vector<int64_t> paddings{};
 };
-class Pad2d : public OpExprGradFunction<Pad2dCaptureState> {
+class PadNd : public OpExprGradFunction<PadNdCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -34,7 +34,7 @@ class Pad2d : public OpExprGradFunction<Pad2dCaptureState> {
     return Maybe<void>::Ok();
   }
-  Maybe<void> Capture(Pad2dCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(PadNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
     CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
@@ -50,9 +50,9 @@ class Pad2d : public OpExprGradFunction<Pad2dCaptureState> {
   AttrMap base_attrs_;
 };
-class ReflectionPad2d : public Pad2d {
+class ReflectionPadNd : public PadNd {
  public:
-  Maybe<void> Apply(const Pad2dCaptureState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
     in_grads->resize(1);
@@ -64,9 +64,9 @@ class ReflectionPad2d : public Pad2d {
   }
 };
-class ReplicationPad2d : public Pad2d {
+class ReplicationPadNd : public PadNd {
  public:
-  Maybe<void> Apply(const Pad2dCaptureState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const PadNdCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
     in_grads->resize(1);
@@ -121,8 +121,10 @@ class ConstantPadNd : public OpExprGradFunction<ConstantPadNdCaptureState> {
 };
 REGISTER_OP_EXPR_GRAD_FUNCTION("pad", ConstantPadNd);
-REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPad2d);
-REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPad2d);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad1d", ReflectionPadNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPadNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad1d", ReplicationPadNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPadNd);
 }  // namespace one
 }  // namespace oneflow
@@ -64,6 +64,7 @@ Maybe<void> ReduceSum::Apply(const ReduceSumCaptureState* ctx, const TensorTuple
 }
 REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSum);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_nansum", ReduceSum);
 struct ReduceProdOpInterpState : public AutoGradCaptureState {
   std::vector<int32_t> axis;
...
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
namespace oneflow {
namespace one {
struct ReduceSumLikeCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
std::vector<int32_t> axis;
};
class ReduceSumLike : public OpExprGradFunction<ReduceSumLikeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ReduceSumLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> ReduceSumLike::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> ReduceSumLike::Capture(ReduceSumLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
CHECK_OR_RETURN(!inputs.at(1)->requires_grad())
<< Error::RuntimeError() << "like tensor does not require grad";
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> ReduceSumLike::Apply(const ReduceSumLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
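  // The backward of reduce_sum_like broadcasts the output grad back along the reduced
  // axes to the saved input's shape; the like tensor (input 1) gets no gradient.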
const auto& x = ctx->SavedTensors().at(0);
in_grads->resize(2);
in_grads->at(0) = JUST(functional::BroadcastLike(out_grads.at(0), x, ctx->axis));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum_like", ReduceSumLike);
} // namespace one
} // namespace oneflow