Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct RollCaptureState : public AutoGradCaptureState {
std::vector<int32_t> shifts;
std::vector<int32_t> dims;
bool requires_grad = false;
};
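// The backward of roll is another roll: shifting the output gradient by the negated shifts along
// the same dims realigns it with the input, i.e. dx = roll(dy, -shifts, dims).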
class Roll : public OpExprGradFunction<RollCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(RollCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const RollCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Roll::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Roll::Capture(RollCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->shifts = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("shifts"));
ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dims"));
return Maybe<void>::Ok();
}
Maybe<void> Roll::Apply(const RollCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
std::vector<int32_t> new_shifts;
new_shifts.resize(ctx->shifts.size());
for (int i = 0; i < new_shifts.size(); ++i) { new_shifts[i] = -ctx->shifts[i]; }
in_grads->at(0) = JUST(functional::Roll(out_grads.at(0), new_shifts, ctx->dims));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("roll", Roll);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct ScalarAddCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
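// y = x + scalar, so dL/dx = dL/dy and the output gradient passes through unchanged.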
class ScalarAdd : public OpExprGradFunction<ScalarAddCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ScalarAddCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarAddCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_add", ScalarAdd);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct ScalarDivCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
Scalar operand;
};
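// y = x / scalar, so dL/dx = dL/dy / scalar; the scalar operand is captured here and reused in
// Apply.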
class ScalarDiv : public OpExprGradFunction<ScalarDivCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ScalarDivCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
if (has_float_operand) {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
} else {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarDivCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::ScalarDiv(JUST(VectorAt(out_grads, 0)), ctx->operand));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_div", ScalarDiv);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// The derivative of FloorDiv is not defined, so its backward is left unimplemented. (author: zhengzekang)
struct ScalarFloorDivCaptureState : public AutoGradCaptureState {};
class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ScalarFloorDivCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarFloorDivCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
UNIMPLEMENTED_THEN_RETURN() << "RuntimeError: derivative for floor_divide is not implemented";
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_floordiv", ScalarFloorDiv);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ScalarFModGradCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class ScalarFModGrad : public OpExprGradFunction<ScalarFModGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ScalarFModGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarFModGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_fmod", ScalarFModGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ScalarMulCaptureState : public AutoGradCaptureState {
bool requires_grad;
Scalar operand;
};
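// y = x * scalar, so dL/dx = dL/dy * scalar with the captured operand.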
class ScalarMul : public OpExprGradFunction<ScalarMulCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ScalarMulCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
if (has_float_operand) {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
} else {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarMulCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
in_grads->at(0) = JUST(functional::ScalarMul(out_grads.at(0), ctx->operand, false));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul", ScalarMul);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ScalarPowCaptureState : public AutoGradCaptureState {
bool requires_grad;
Scalar operand;
};
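// y = x ^ operand, so dL/dx = dL/dy * operand * x^(operand - 1); ScalarPowGrad is called with the
// saved input x, the output gradient, and the captured exponent.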
class ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
if (has_float_operand) {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
} else {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
}
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
MutableAttrMap attrs;
in_grads->resize(1);
if (ctx->requires_grad) {
in_grads->at(0) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));
}
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow", ScalarPow);
class ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs[0]->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
if (has_float_operand) {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
} else {
ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
}
ctx->SaveTensorForBackward(inputs[0]);
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors()[0];
MutableAttrMap attrs;
in_grads->resize(1);
if (ctx->requires_grad) {
(*in_grads)[0] = JUST(functional::ScalarReversePowGrad(x, out_grads[0], ctx->operand));
}
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_reverse_pow", ScalarReversePow);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ScatterNdCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
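// scatter_nd writes `updates` into a zero tensor at `indices`, so the gradient of `updates` is
// gather_nd(dy, indices); the integer `indices` input gets no gradient.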
class ScatterNd : public OpExprGradFunction<ScatterNdCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(ScatterNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(1)->requires_grad();
if (ctx->requires_grad) {
ctx->SaveTensorForBackward(inputs.at(0)); // indices
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ScatterNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& indices = ctx->SavedTensors().at(0);
in_grads->at(1) = JUST(functional::GatherNd(out_grads.at(0), indices));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scatter_nd", ScatterNd);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
namespace oneflow {
namespace one {
struct SelectTopNCaptureState : public AutoGradCaptureState {
TensorTuple inputs;
std::vector<bool> requires_grad;
int32_t top_n = 0;
};
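// select_top_n forwards the first top_n inputs unchanged, so their gradients pass straight
// through; the remaining inputs that require grad receive zero tensors of matching shape, dtype,
// and device.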
class SelectTopN : public OpExprGradFunction<SelectTopNCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(SelectTopNCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->inputs = inputs;
ctx->top_n = JUST(attrs.GetAttr<int32_t>("top_n"));
ctx->requires_grad.resize(inputs.size());
for (int i = 0; i < ctx->requires_grad.size(); ++i) {
ctx->requires_grad.at(i) = inputs.at(i)->requires_grad();
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SelectTopNCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size()); // NOLINT(maybe-need-error-msg)
for (int i = 0; i < ctx->top_n; ++i) {
if (!ctx->requires_grad.at(i)) { continue; }
in_grads->at(i) = out_grads.at(i);
}
for (int i = ctx->top_n; i < ctx->inputs.size(); ++i) {
if (!ctx->requires_grad.at(i)) { continue; }
const auto& tensor = ctx->inputs.at(i);
in_grads->at(i) = JUST(StaticZerosTensor::MakeTensor(
tensor->shape(), tensor->dtype()->data_type(), JUST(tensor->device())));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("select_top_n", SelectTopN);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SliceCaptureState : public AutoGradCaptureState {
Shape like_shape;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
};
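// Backward of slice: SliceGrad scatters dy back into a zero tensor shaped like the original input,
// using the captured start/stop/step.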
class Slice : public OpExprGradFunction<SliceCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SliceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
ctx->like_shape = *(inputs[0]->shape());
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
(*in_grads)[0] = JUST(
functional::SliceGrad(out_grads[0], ctx->like_shape, ctx->start, ctx->stop, ctx->step));
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
struct SliceUpdateCaptureState : public AutoGradCaptureState {
bool requires_grad_ref = false;
bool requires_grad_value = false;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
Shape value_shape; // used to calculate ref gradient
Symbol<NdSbp> value_sbp;
};
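// Backward of slice_update(ref, value): the ref gradient is dy with the updated window overwritten
// by zeros, and the value gradient is the corresponding slice of dy.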
class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SliceUpdateCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad_ref = inputs[0]->requires_grad();
ctx->requires_grad_value = inputs[1]->requires_grad();
if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
if (ctx->requires_grad_ref) {
ctx->value_shape = *(inputs[1]->shape());
if (inputs[1]->is_consistent()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SliceUpdateCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->requires_grad_ref) {
std::shared_ptr<Tensor> zeros;
if (out_grads[0]->is_local()) {
zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(),
JUST(out_grads[0]->device())));
} else {
const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
zeros =
JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
}
(*in_grads)[0] = JUST(functional::SliceUpdate(out_grads[0], zeros, ctx->start, ctx->stop,
ctx->step, /*inplace=*/false));
}
if (ctx->requires_grad_value) {
(*in_grads)[1] = JUST(functional::Slice(out_grads[0], ctx->start, ctx->stop, ctx->step,
/*enable_view_slice=*/false));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("slice_update", SliceUpdate);
REGISTER_OP_EXPR_GRAD_FUNCTION("slice", Slice);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SmoothL1LossCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
float beta = 0.0;
};
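// Only the prediction receives a gradient; SmoothL1LossGrad applies the piecewise smooth-L1
// derivative ((p - t) / beta inside the quadratic region, sign(p - t) outside) scaled by dy.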
class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->requires_grad = inputs.at(0)->requires_grad(); // prediction
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(inputs.at(0)); // prediction
ctx->SaveTensorForBackward(inputs.at(1)); // label
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->beta = JUST(composed_attrs.GetAttr<float>("beta"));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
const auto& prediction = ctx->SavedTensors().at(0);
const auto& label = ctx->SavedTensors().at(1);
in_grads->at(0) =
JUST(functional::SmoothL1LossGrad(out_grads.at(0), prediction, label, ctx->beta));
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("smooth_l1_loss", SmoothL1Loss); // todo: name
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SoftmaxCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
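// Softmax backward needs only the saved output y: the standard identity is
// dx = y * (dy - sum(dy * y)) over the softmax dimension, computed by SoftmaxGrad.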
class Softmax : public OpExprGradFunction<SoftmaxCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> Softmax::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
Maybe<void> Softmax::Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) return Maybe<void>::Ok();
ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Softmax::Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) return Maybe<void>::Ok();
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& y = ctx->SavedTensors().at(0);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::SoftmaxGrad(dy, y));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("softmax", Softmax);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SoftmaxCrossEntropyGradState : public AutoGradCaptureState {
bool requires_grad = false;
};
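// softmax_cross_entropy produces two outputs (loss and prob). For normalized soft labels the
// prediction gradient is dy * (prob - label); the label input receives no gradient.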
class SoftmaxCrossEntropy : public OpExprGradFunction<SoftmaxCrossEntropyGradState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SoftmaxCrossEntropyGradState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SoftmaxCrossEntropyGradState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> SoftmaxCrossEntropy::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> SoftmaxCrossEntropy::Capture(SoftmaxCrossEntropyGradState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(inputs.at(1)); // label
ctx->SaveTensorForBackward(outputs.at(1)); // prob
return Maybe<void>::Ok();
}
Maybe<void> SoftmaxCrossEntropy::Apply(const SoftmaxCrossEntropyGradState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& label = ctx->SavedTensors().at(0);
const auto& prob = ctx->SavedTensors().at(1);
in_grads->resize(2); // prediction, label
(*in_grads)[0] = JUST(functional::SoftmaxCrossEntropyGrad(dy, label, prob));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("softmax_cross_entropy", SoftmaxCrossEntropy);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SparseCrossEntropyCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t depth = -1;
size_t prediction_index = -1;
size_t label_index = -1;
};
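// Backward for sparse (integer-label) cross entropy on probability inputs. The template flag
// selects the distributed ("ms") kernel, presumably used when the prediction's class dimension is
// split across devices; only the prediction receives a gradient.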
template<bool is_distributed>
class SparseCrossEntropy : public OpExprGradFunction<SparseCrossEntropyCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(SparseCrossEntropyCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
ctx->prediction_index = ctx->SaveTensorForBackward(inputs.at(0)); // prediction
ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1)); // label
return Maybe<void>::Ok();
}
Maybe<void> Apply(const SparseCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& prediction = ctx->SavedTensors().at(ctx->prediction_index);
const auto& label = ctx->SavedTensors().at(ctx->label_index);
in_grads->resize(2);
if (is_distributed) {
in_grads->at(0) = JUST(
functional::SparseCrossEntropyMsGrad(prediction, label, out_grads.at(0), ctx->depth));
} else {
in_grads->at(0) =
JUST(functional::SparseCrossEntropyGrad(prediction, label, out_grads.at(0), ctx->depth));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_cross_entropy_ms", SparseCrossEntropy<true>);
REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_cross_entropy", SparseCrossEntropy<false>);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SparseSoftmaxCrossEntropyCaptureState : public AutoGradCaptureState {
int64_t depth;
};
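// The op produces (prob, out); the backward combines the loss gradient out_grads[1] with the saved
// prob and label, which amounts to dx = dy * (prob - one_hot(label)) for the prediction only.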
class SparseSoftmaxCrossEntropy : public OpExprGradFunction<SparseSoftmaxCrossEntropyCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> SparseSoftmaxCrossEntropy::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(outputs.at(0)); // prob
ctx->SaveTensorForBackward(inputs.at(1)); // label
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(1);
const auto& prob = ctx->SavedTensors().at(0);
const auto& label = ctx->SavedTensors().at(1);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
// SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
// require gradient.
in_grads->resize(2);
in_grads->at(0) = JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy", SparseSoftmaxCrossEntropy);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SplitLikeCaptureState : public AutoGradCaptureState {
int64_t axis;
bool requires_grad;
};
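// split_like splits input 0 along `axis` into pieces shaped like the remaining "like" inputs; the
// backward concatenates the output gradients along the same axis, substituting zeros_like for any
// missing gradient.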
class SplitLike : public OpExprGradFunction<SplitLikeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> SplitLike::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> SplitLike::Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
for (int i = 0; i < outputs.size(); ++i) { ctx->SaveTensorForBackward(outputs.at(i)); }
return Maybe<void>::Ok();
}
Maybe<void> SplitLike::Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(out_grads.size() + 1);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& saved_tensors = ctx->SavedTensors();
TensorTuple inputs;
inputs.reserve(out_grads.size());
for (int i = 0; i < out_grads.size(); ++i) {
const auto& out_grad_i = out_grads.at(i);
if (out_grad_i.get()) {
inputs.emplace_back(out_grad_i);
} else {
const auto& zero_grad = JUST(functional::ZerosLike(saved_tensors.at(i)));
inputs.emplace_back(zero_grad);
}
}
in_grads->at(0) = JUST(functional::Concat(inputs, ctx->axis));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("split_like", SplitLike);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct SqueezeCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
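// Squeeze only removes size-1 dimensions, so the backward simply reshapes dy back to the saved
// input's shape.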
class Squeeze : public OpExprGradFunction<SqueezeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Squeeze::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Squeeze::Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Squeeze::Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(0);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("squeeze", Squeeze);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct StackCaptureState : public AutoGradCaptureState {
std::vector<bool> requires_grad;
int64_t axis = 1;
int64_t input_num = 2;
};
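// Backward of stack: StackGrad splits the stacked gradient along `axis` back into one gradient per
// input, and only inputs that require grad receive theirs.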
class Stack : public OpExprGradFunction<StackCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(StackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Stack::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Stack::Capture(StackCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad.resize(inputs.size());
for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); }
ctx->input_num = inputs.size();
return Maybe<void>::Ok();
}
Maybe<void> Stack::Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(ctx->input_num);
TensorTuple like(ctx->input_num);
for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); }
const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis));
CHECK_EQ_OR_RETURN(results->size(), ctx->input_num)
<< Error::RuntimeError() << "The number of results (" << results->size()
<< ") must match the number of inputs (" << ctx->input_num << ")";
for (int i = 0; i < ctx->input_num; ++i) {
if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("stack", Stack);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct TensorScalarCaptureState : public AutoGradCaptureState {
bool x_requires_grad;
bool scalar_requires_grad;
};
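// These ops take the scalar operand as a broadcastable tensor (input 1). The x gradient follows
// the usual elementwise rule, while the scalar gradient reduce-sums dy (or its product/quotient
// with the saved tensors) over all axes down to the scalar.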
class TensorScalarAddOrSub : public OpExprGradFunction<TensorScalarCaptureState> {
public:
TensorScalarAddOrSub() = default;
virtual ~TensorScalarAddOrSub() = default;
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
};
Maybe<void> TensorScalarAddOrSub::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> TensorScalarAddOrSub::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
return Maybe<void>::Ok();
}
class TensorScalarAdd : public TensorScalarAddOrSub {
public:
Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); }
if (ctx->scalar_requires_grad) {
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), axes_vec, false));
}
return Maybe<void>::Ok();
}
};
class TensorScalarSub : public TensorScalarAddOrSub {
public:
Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::Identity(out_grads.at(0))); }
if (ctx->scalar_requires_grad) {
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
const auto& reduce_sum =
JUST(functional::ReduceSum(out_grads.at(0), axes_vec, /*keepdims=*/false));
// y = x - scalar, so d(y)/d(scalar) = -1 and the scalar gradient is the negated reduce-sum.
in_grads->at(1) = JUST(functional::ScalarMul(reduce_sum, /*other=*/-1.0, false));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_add_by_tensor", TensorScalarAdd);
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_sub_by_tensor", TensorScalarSub);
class TensorScalarMul : public OpExprGradFunction<TensorScalarCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> TensorScalarMul::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> TensorScalarMul::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> TensorScalarMul::Apply(const TensorScalarCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& scalar = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::Mul(out_grads.at(0), scalar));
}
if (ctx->scalar_requires_grad) {
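// x was saved after the scalar (if any), so its index in SavedTensors equals x_requires_grad.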
const auto& x = ctx->SavedTensors().at(ctx->x_requires_grad);
const auto& y = JUST(functional::Mul(out_grads.at(0), x));
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
in_grads->at(1) = JUST(functional::ReduceSum(y, axes_vec, /*keepdims=*/false));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul_by_tensor", TensorScalarMul);
class TensorScalarDiv : public OpExprGradFunction<TensorScalarCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
std::shared_ptr<OpExpr> tensor_scalar_div_op_;
std::shared_ptr<OpExpr> broadcast_div_grad_op_;
};
Maybe<void> TensorScalarDiv::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> TensorScalarDiv::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
if (ctx->x_requires_grad || ctx->scalar_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(1));
}
if (ctx->scalar_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
return Maybe<void>::Ok();
}
Maybe<void> TensorScalarDiv::Apply(const TensorScalarCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& scalar = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::Div(out_grads.at(0), scalar));
}
if (ctx->scalar_requires_grad) {
const auto& scalar = ctx->SavedTensors().at(0);
const auto& y = ctx->SavedTensors().at(1);
in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), y, scalar));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_div_by_tensor", TensorScalarDiv);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct TensorScatterNdUpdateCaptureState : public AutoGradCaptureState {
bool tensor_requires_grad = false;
bool update_requires_grad = false;
};
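// tensor_scatter_nd_update overwrites `tensor` at `indices` with `update`: the update gradient is
// gather_nd(dy, indices), and the tensor gradient is dy with zeros scattered over the overwritten
// positions.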
class TensorScatterNdUpdate : public OpExprGradFunction<TensorScatterNdUpdateCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->tensor_requires_grad = inputs.at(0)->requires_grad();
ctx->update_requires_grad = inputs.at(2)->requires_grad();
if (ctx->update_requires_grad || ctx->tensor_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(1)); // indices
}
if (ctx->tensor_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(2)); // update: only use meta information
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const TensorScatterNdUpdateCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3);
if (ctx->update_requires_grad) {
const auto& indices = ctx->SavedTensors().at(0);
in_grads->at(2) = JUST(functional::GatherNd(out_grads.at(0), indices));
}
if (ctx->tensor_requires_grad) {
const auto& indices = ctx->SavedTensors().at(0);
const auto& update = ctx->SavedTensors().at(1);
const auto& temp = JUST(functional::ZerosLike(update));
in_grads->at(0) = JUST(
functional::TensorScatterNdUpdate(out_grads.at(0), indices, temp, /*inplace=*/false));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("tensor_scatter_nd_update", TensorScatterNdUpdate);
} // namespace one
} // namespace oneflow