Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CombinedMarginLossCaptureState : public AutoGradCaptureState {
float m1;
float m2;
float m3;
int64_t depth;
size_t label_index;
size_t theta_index;
bool requires_grad;
};
class CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(CombinedMarginLossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad(); // x
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->label_index = ctx->SaveTensorForBackward(inputs.at(1)); // label
ctx->theta_index = ctx->SaveTensorForBackward(outputs.at(1)); // theta
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->m1 = JUST(composed_attrs.GetAttr<float>("m1"));
ctx->m2 = JUST(composed_attrs.GetAttr<float>("m2"));
ctx->m3 = JUST(composed_attrs.GetAttr<float>("m3"));
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
return Maybe<void>::Ok();
}
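// Backward: only x receives a gradient. The saved label and theta (the second
// forward output) are passed to combined_margin_loss_grad together with the
// captured margin parameters m1/m2/m3 and the class depth.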
Maybe<void> Apply(const CombinedMarginLossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& label = ctx->SavedTensors().at(ctx->label_index);
const auto& theta = ctx->SavedTensors().at(ctx->theta_index);
in_grads->at(0) = JUST(functional::CombinedMarginLossGrad(
out_grads.at(0), label, theta, ctx->m1, ctx->m2, ctx->m3, ctx->depth));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("combined_margin_loss", CombinedMarginLoss);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ConcatCaptureState : public AutoGradCaptureState {
std::vector<bool> requires_grad;
int64_t axis;
int64_t input_num;
};
class Concat : public OpExprGradFunction<ConcatCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Concat::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Concat::Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad.resize(inputs.size());
for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); }
ctx->input_num = inputs.size();
return Maybe<void>::Ok();
}
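// Backward: the gradient of concat is split_like -- out_grad is split along
// `axis` into pieces shaped like the saved inputs, and each piece is assigned
// to the corresponding input that requires grad.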
Maybe<void> Concat::Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(ctx->input_num);
TensorTuple like(ctx->input_num);
for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); }
if (ctx->input_num == 1) {
in_grads->at(0) = out_grads.at(0);
} else {
const auto& results = JUST(functional::SplitLike(out_grads.at(0), like, ctx->axis));
CHECK_EQ_OR_RETURN(results->size(), ctx->input_num)
<< Error::RuntimeError() << "The size of results (" << results->size()
<< ") must match the size of inputs (" << ctx->input_num << ")";
for (int i = 0; i < ctx->input_num; ++i)
if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("concat", Concat);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CastConsistentCaptureState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
std::shared_ptr<const Shape> shape;
Symbol<DType> dtype;
};
class CastToConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastFromConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));
return Maybe<void>::Ok();
}
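// Backward: box the global out_grad to the dual SBP on the captured placement
// (GetDualNdSbp swaps broadcast and partial-sum; split stays split), then cast
// it back to a local tensor with the cast_from_consistent grad op.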
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
std::shared_ptr<Tensor> out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_to_consistent but got local tensor";
{
Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
out_grad = JUST(functional::ToConsistent(out_grad, parallel_desc_constraint,
*JUST(GetSbpList(nd_sbp_constraint)),
GetNoneSbpList(), /* check_meta */ false));
}
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_to_consistent", CastToConsistent);
class CastFromConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastToConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
const auto& input = inputs.at(0);
CHECK_OR_RETURN(input->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_from_consistent but got local tensor";
ctx->parallel_desc = JUST(input->parallel_desc());
ctx->nd_sbp = JUST(input->nd_sbp());
ctx->shape = input->shape();
ctx->dtype = input->dtype();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", *ctx->shape));
JUST(attrs.SetAttr<DataType>("dtype", ctx->dtype->data_type()));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(
*grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_from_consistent", CastFromConsistent);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
namespace one {
struct ConsistentToConsistentState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
};
class ConsistentToConsistentGradFunction : public OpExprGradFunction<ConsistentToConsistentState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const ConsistentToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();
return Maybe<void>::Ok();
}
Maybe<void> Capture(ConsistentToConsistentState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());
ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());
return Maybe<void>::Ok();
}
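// Backward: re-box out_grad onto the input's placement, using grad_nd_sbp if
// the forward op specified one and out_grad's own SBP otherwise; the input's
// original SBP is passed along as the gradient SBP of this boxing call.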
Maybe<void> Apply(const ConsistentToConsistentState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for consistent_to_consistent but got local tensor";
in_grads->resize(1);
const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
(*in_grads)[0] = JUST(one::functional::ToConsistent(
out_grad, ctx->parallel_desc, *grad_sbp_list, *grad_grad_sbp_list, /* check_meta */ false));
return Maybe<void>::Ok();
}
private:
Optional<Symbol<NdSbp>> grad_nd_sbp_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("consistent_to_consistent", ConsistentToConsistentGradFunction);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ConvolutionNdCaptureState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool weight_requires_grad = false;
size_t input_index;
size_t weight_index;
std::string data_format;
std::vector<int32_t> padding_before;
std::vector<int32_t> kernel_size;
std::vector<int32_t> strides;
std::vector<int32_t> dilation_rate;
int32_t groups;
};
class ConvolutionNd : public OpExprGradFunction<ConvolutionNdCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->weight_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); }
if (ctx->input_requires_grad) {
ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight
}
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
return Maybe<void>::Ok();
}
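// Backward: dx = conv_data_grad(dy, weight, x) and dw = conv_filter_grad(dy, x);
// the number of spatial dims is inferred from kernel_size, so the same functor
// serves conv1d/conv2d/conv3d.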
Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
size_t num_spatial_dims = ctx->kernel_size.size();
if (ctx->input_requires_grad) {
const auto& weight = ctx->SavedTensors().at(ctx->weight_index);
const auto& input = ctx->SavedTensors().at(ctx->input_index);
in_grads->at(0) = JUST(functional::ConvDataGrad(
out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
if (ctx->weight_requires_grad) {
const auto& input = ctx->SavedTensors().at(ctx->input_index);
in_grads->at(1) = JUST(functional::ConvFilterGrad(
out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CopyCaptureState : public AutoGradCaptureState {
std::string device_type;
int64_t device_id;
};
class Copy : public OpExprGradFunction<CopyCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->device_type = JUST(inputs.at(0)->device())->type();
ctx->device_id = JUST(inputs.at(0)->device())->device_id();
return Maybe<void>::Ok();
}
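// Backward: the gradient of a device copy is simply a copy of out_grad back to
// the source device captured above.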
Maybe<void> Apply(const CopyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
(*in_grads)[0] = JUST(
functional::Copy(out_grads[0], ctx->device_type, ctx->device_id, /*pin_memory=*/false));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("copy", Copy);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CTCLossCaptureState : public AutoGradCaptureState {
int64_t max_target_length;
int32_t blank;
bool zero_infinity;
bool requires_grad;
};
class CTCLoss : public OpExprGradFunction<CTCLossCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> grad_op_;
};
Maybe<void> CTCLoss::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->max_target_length = JUST(composed_attrs.GetAttr<int64_t>("max_target_length"));
ctx->blank = JUST(composed_attrs.GetAttr<int32_t>("blank"));
ctx->zero_infinity = JUST(composed_attrs.GetAttr<bool>("zero_infinity"));
CHECK_EQ_OR_RETURN(inputs.size(), 4); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->SaveTensorForBackward(outputs.at(0)); // loss
ctx->SaveTensorForBackward(outputs.at(1)); // alpha
ctx->SaveTensorForBackward(inputs.at(0)); // log_probs
ctx->SaveTensorForBackward(inputs.at(1)); // targets
ctx->SaveTensorForBackward(inputs.at(2)); // input_lengths
ctx->SaveTensorForBackward(inputs.at(3)); // target_lengths
return Maybe<void>::Ok();
}
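// Backward: only log_probs receives a gradient; the saved forward loss and
// alpha (the CTC forward lattice values) are re-used by ctc_loss_grad together
// with the original targets and length tensors.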
Maybe<void> CTCLoss::Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const auto& grad_out = out_grads.at(0);
const auto& loss = ctx->SavedTensors().at(0);
const auto& alpha = ctx->SavedTensors().at(1);
const auto& log_probs = ctx->SavedTensors().at(2);
const auto& targets = ctx->SavedTensors().at(3);
const auto& input_lengths = ctx->SavedTensors().at(4);
const auto& target_lengths = ctx->SavedTensors().at(5);
in_grads->resize(4);
in_grads->at(0) = JUST(functional::CtcLossGrad(grad_out, log_probs, targets, input_lengths,
target_lengths, loss, alpha, ctx->blank,
ctx->zero_infinity, ctx->max_target_length));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("ctc_loss", CTCLoss);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#if CUDA_VERSION >= 11060
namespace oneflow {
namespace one {
struct CublasFusedMLPCaptureState : public AutoGradCaptureState {
int32_t weight_num = 0;
bool skip_final_activation = false;
bool x_requires_grad = false;
std::vector<bool> weights_requires_grad;
std::vector<bool> biases_requires_grad;
};
class CublasFusedMLP : public OpExprGradFunction<CublasFusedMLPCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const CublasFusedMLPCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
protected:
AttrMap base_attrs_;
};
Maybe<void> CublasFusedMLP::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_OR_RETURN(inputs.size() % 2 == 1)
<< Error::RuntimeError() << "Both weight and bias should be passed together";
int32_t weight_num = (inputs.size() - 1) / 2;
ctx->weight_num = weight_num;
ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
ctx->weights_requires_grad.resize(weight_num);
ctx->biases_requires_grad.resize(weight_num);
for (int32_t i = 0; i < weight_num; i++) {
ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad(); // NOLINT
ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad(); // NOLINT
}
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // x. idx_sum:1
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1))); // weights. idx_sum:1+w
}
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // final layers output. idx_sum:2+w
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(
JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w
}
for (int32_t i = 0; i < weight_num - 1; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden.
}
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->skip_final_activation = JUST(composed_attrs.GetAttr<bool>("skip_final_activation"));
return Maybe<void>::Ok();
}
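// Backward. SavedTensors layout: [0]=x, [1..w]=weights, [1+w]=final output,
// [2+w..1+2w]=cublas aux buffers, [2+2w..]=hidden activations (w = weight_num).
// The loop below walks layers from last to first: each step fuses the
// bias/ReLU/matmul gradient through cublas and computes the weight gradient
// with a transposed matmul; the first layer is handled with two plain matmuls.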
Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
int32_t weight_num = ctx->weight_num;
in_grads->resize(1 + 2 * weight_num);
std::shared_ptr<one::Tensor> last_bias_dy = JUST(VectorAt(out_grads, 0));
if (!ctx->skip_final_activation) {
// step1: use dy and final output to get last layer's relu grad.
last_bias_dy = JUST(functional::ReluGrad(JUST(VectorAt(out_grads, 0)),
JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num))));
}
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}
TensorTuple hiddens(weight_num - 1);
TensorTuple weights(weight_num);
TensorTuple cublas_auxs(weight_num);
TensorTuple dgrad(weight_num);
std::shared_ptr<one::Tensor> x = JUST(VectorAt(ctx->SavedTensors(), 0));
for (int32_t i = 0; i < weight_num; ++i) {
weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i));
}
for (int32_t i = 0; i < weight_num; ++i) {
cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));
}
for (int32_t i = 0; i < weight_num - 1; ++i) {
hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));
}
std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {
// For the last layer, cublas_dy is already last_bias_dy; for earlier layers,
// reuse the dgrad produced by the layer above.
if (hidden_layer_idx != weight_num - 1) {
cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));
}
/*
Here we use cublas to compute bias + relu + matmul grad.
Then use Matmul to compute weight grad.
*/
const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad(
cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),
JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0));
// dgrad
dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT
if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =
matmul_relu_bias_bgrad->at(1); // NOLINT
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(
cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));
}
}
// For the first layer, we need to use 2 matmul to get grads.
std::shared_ptr<one::Tensor> last_dy;
if (weight_num != 1) {
last_dy = JUST(VectorAt(dgrad, 1));
} else {
last_dy = last_bias_dy;
}
if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("cublas_fused_mlp", CublasFusedMLP);
} // namespace one
} // namespace oneflow
#endif // CUDA_VERSION >= 11060
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CumCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t dim = 0;
};
template<typename StateT>
class CumGrad : public OpExprGradFunction<StateT> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
protected:
AttrMap base_attrs_;
};
class CumsumGrad : public CumGrad<CumCaptureState> {
public:
Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
return Maybe<void>::Ok();
}
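// Backward of cumsum: dL/dx_i = sum_{j >= i} dL/dy_j, i.e. a reverse cumulative
// sum, implemented as flip(cumsum(flip(dy, dim), dim), dim).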
Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
std::vector<int32_t> flip_dim(1, ctx->dim);
(*in_grads)[0] = JUST(
functional::Flip(JUST(functional::Cumsum(JUST(functional::Flip(out_grads[0], flip_dim)),
ctx->dim, out_grads[0]->dtype())),
flip_dim));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad);
class CumProdGrad : public CumGrad<CumCaptureState> {
public:
Maybe<void> Capture(CumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
ctx->SaveTensorForBackward(outputs.at(0));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
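// Backward of cumprod delegates to the dedicated cumprod_grad functor, which
// needs both the forward output and the original input (a plain y/x division
// would break when the input contains zeros).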
Maybe<void> Apply(const CumCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
in_grads->at(0) = JUST(functional::CumprodGrad(out_grads.at(0), ctx->SavedTensors().at(0),
ctx->SavedTensors().at(1), ctx->dim));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cumprod", CumProdGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdint>
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DeConvolutionNdCaptureState : public AutoGradCaptureState {
bool weight_requires_grad = false;
bool activation_requires_grad = false;
size_t ndims;
std::string data_format;
std::vector<int32_t> padding_before;
std::vector<int32_t> kernel_size;
std::vector<int32_t> strides;
std::vector<int32_t> dilation_rate;
int32_t groups;
};
class DeConvolutionNd : public OpExprGradFunction<DeConvolutionNdCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DeConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> DeConvolutionNd::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> DeConvolutionNd::Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->activation_requires_grad = inputs.at(0)->requires_grad();
ctx->weight_requires_grad = inputs.at(1)->requires_grad();
if (ctx->activation_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(1)); // weight
}
if (ctx->weight_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(0)); // x
}
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
ctx->ndims = ctx->kernel_size.size();
return Maybe<void>::Ok();
}
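// Backward: dx of a transposed convolution is an ordinary forward convolution
// of out_grad with the same weight, sliced back to the input shape; dw is
// conv_filter_grad with x and out_grad in swapped roles relative to the regular
// convolution backward.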
Maybe<void> DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
if (ctx->activation_requires_grad) {
const auto& x = ctx->SavedTensors().at(1);
std::vector<int64_t> start, stop, step;
for (int i = 0; i < x->shape()->NumAxes(); i++) {
start.emplace_back(0);
stop.emplace_back(x->shape()->At(i));
step.emplace_back(1);
}
const auto& weight = ctx->SavedTensors().at(0);
if (ctx->ndims == 1) {
std::shared_ptr<Tensor> result = JUST(functional::Conv1d(
out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));
in_grads->at(0) = result;
} else if (ctx->ndims == 2) {
std::shared_ptr<Tensor> result = JUST(functional::Conv2d(
out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));
in_grads->at(0) = result;
} else if (ctx->ndims == 3) {
std::shared_ptr<Tensor> result = JUST(functional::Conv3d(
out_grads.at(0), weight, Optional<Tensor>(), ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
result = JUST(functional::Slice(result, start, stop, step, /*enable_view_slice=*/false));
in_grads->at(0) = result;
} else {
UNIMPLEMENTED_THEN_RETURN() << "Invalid ndim " << ctx->ndims << " for conv functor";
}
}
if (ctx->weight_requires_grad) {
int idx = ctx->activation_requires_grad;
const auto& x = ctx->SavedTensors().at(idx);
in_grads->at(1) = JUST(functional::ConvFilterGrad(
x, out_grads.at(0), ctx->ndims, ctx->kernel_size, ctx->strides, ctx->padding_before,
ctx->dilation_rate, ctx->groups, ctx->data_format));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("deconv1d", DeConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("deconv2d", DeConvolutionNd);
REGISTER_OP_EXPR_GRAD_FUNCTION("deconv3d", DeConvolutionNd);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DiagCaptureState : public AutoGradCaptureState {
bool requires_grad;
int32_t diagonal;
};
class Diag : public OpExprGradFunction<DiagCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DiagCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Diag::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Diag::Capture(DiagCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->diagonal = JUST(composed_attrs.GetAttr<int32_t>("diagonal"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
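// Backward: diag_grad scatters out_grad back to the diagonal positions read in
// the forward pass; the saved input supplies the shape of the gradient.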
Maybe<void> Diag::Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::DiagGrad(out_grads.at(0), x, ctx->diagonal));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("diag", Diag);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DiagonalInterpState : public AutoGradCaptureState {
bool requires_grad = false;
int32_t offset = 0;
};
class Diagonal : public OpExprGradFunction<DiagonalInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DiagonalInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Diagonal::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Diagonal::Capture(DiagonalInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->offset = JUST(composed_attrs.GetAttr<int32_t>("offset"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Diagonal::Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::DiagonalGrad(out_grads.at(0), x, ctx->offset));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("diagonal", Diagonal);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DimGatherCaptureState : public AutoGradCaptureState {
int32_t dim;
bool requires_grad;
};
class DimGather : public OpExprGradFunction<DimGatherCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> bw_dim_gather_op_;
};
Maybe<void> DimGather::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> DimGather::Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1));
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
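// Backward: the gradient of gather-along-dim is dim_scatter_add_like -- out_grad
// is scatter-added back along `dim` at `index` into a zero tensor shaped like
// the saved input.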
Maybe<void> DimGather::Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(1);
in_grads->at(0) = JUST(functional::DimScatterAddLike(like, ctx->dim, index, out_grads.at(0)));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_gather", DimGather);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DimScatterCaptureState : public AutoGradCaptureState {
int32_t dim;
bool input_requires_grad;
bool src_requires_grad;
};
enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD };
template<SCATTER_TYPE T>
class DimScatter : public OpExprGradFunction<DimScatterCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
Maybe<void> ApplyCommon(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const;
private:
AttrMap base_attrs_;
};
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->src_requires_grad = inputs.at(2)->requires_grad();
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1)); // index saved
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->resize(3);
if (ctx->src_requires_grad) {
in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), ctx->dim, index, false));
}
return Maybe<void>::Ok();
}
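// The src gradient (shared via ApplyCommon) gathers out_grad at `index`; the
// input gradient differs per mode: scatter_update zeroes the overwritten
// positions, while scatter_add passes out_grad through unchanged.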
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (ctx->input_requires_grad) {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f));
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> DimScatterUpdateScalar::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1)); // index saved
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->resize(2);
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<SCATTER_TYPE::SCATTER_UPDATE>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<SCATTER_TYPE::SCATTER_ADD>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DotCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool y_requires_grad = false;
size_t x_offset = 0;
size_t y_offset = 0;
};
class DotGrad : public OpExprGradFunction<DotCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(DotCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->x_requires_grad = inputs.at(0)->requires_grad();
if (ctx->x_requires_grad) { ctx->x_offset = ctx->SaveTensorForBackward(inputs.at(1)); }
ctx->y_requires_grad = inputs.at(1)->requires_grad();
if (ctx->y_requires_grad) { ctx->y_offset = ctx->SaveTensorForBackward(inputs.at(0)); }
return Maybe<void>::Ok();
}
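// Backward of z = dot(x, y): dz/dx = y and dz/dy = x, so each input's gradient
// is the other saved input scaled by the scalar out_grad (note that x_offset
// stores the saved y and y_offset stores the saved x).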
Maybe<void> Apply(const DotCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_offset);
const auto& results = JUST(functional::Mul(x, out_grads.at(0)));
in_grads->at(0) = results;
}
if (ctx->y_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_offset);
const auto& results = JUST(functional::Mul(y, out_grads.at(0)));
in_grads->at(1) = results;
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("dot", DotGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DropoutCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
bool has_addend = false;
float rate = 0.0;
};
class Dropout : public OpExprGradFunction<DropoutCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Dropout::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Dropout::Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->rate = JUST(composed_attrs.GetAttr<float>("rate"));
if (inputs.size() == 1) {
ctx->has_addend = false;
} else if (inputs.size() == 2) {
ctx->has_addend = true;
} else {
UNIMPLEMENTED();
}
ctx->SaveTensorForBackward(outputs.at(1)); // output mask
return Maybe<void>::Ok();
}
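// Backward: out_grad is scaled by 1/(1 - rate) wherever the saved mask is set
// (scale is forced to 0 when rate == 1), and the optional addend input receives
// out_grad unchanged.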
Maybe<void> Dropout::Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // Output has y and mask.
float scale = 0.0f; // When dropout rate = 1.0, we set scale as zero.
if (ctx->rate < 1.0f) { scale = 1.0f / (1.0f - ctx->rate); }
const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);
if (ctx->has_addend) {
in_grads->resize(2);
in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale));
in_grads->at(1) = out_grads.at(0);
return Maybe<void>::Ok();
} else {
in_grads->resize(1);
in_grads->at(0) = JUST(functional::DropoutGrad(out_grads.at(0), mask, scale));
return Maybe<void>::Ok();
}
}
REGISTER_OP_EXPR_GRAD_FUNCTION("dropout", Dropout);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
namespace {
Maybe<one::UserOpExpr> EagerNcclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {
return one::OpBuilder("eager_nccl_reduce", *JUST(UniqueStr("eager_nccl_reduce")))
.Input("in")
.Output("out")
.Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
.Attr<int64_t>("root", root)
.Build();
}
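// Cache one reduce op expression per (parallel_desc, root) pair so repeated backward passes
// reuse it instead of rebuilding it. The cache is thread_local, so each thread keeps its own
// map and no locking is needed.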
Maybe<one::UserOpExpr> FindOrCreatEagerNcclReduceOpExpr(Symbol<ParallelDesc> parallel_desc,
int64_t root) {
thread_local HashMap<std::pair<Symbol<ParallelDesc>, int64_t>, std::shared_ptr<one::UserOpExpr>>
parallel_desc_and_root_device2eager_nccl_reduce;
const auto& key = std::make_pair(parallel_desc, root);
auto iter = parallel_desc_and_root_device2eager_nccl_reduce.find(key);
if (iter == parallel_desc_and_root_device2eager_nccl_reduce.end()) {
std::shared_ptr<UserOpExpr> op_expr = JUST(EagerNcclReduce(parallel_desc, root));
iter = parallel_desc_and_root_device2eager_nccl_reduce.emplace(key, op_expr).first;
}
return iter->second;
}
} // namespace
struct EagerNcclBroadcastCaptureState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
int64_t root;
};
class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
return Maybe<void>::Ok();
}
Maybe<void> Capture(EagerNcclBroadcastCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>("root"));
ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
return Maybe<void>::Ok();
}
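  // The adjoint of broadcasting a tensor from `root` to every rank is a reduction (a sum in
  // the usual case) of the per-rank gradients back to `root`, which is what the cached
  // eager_nccl_reduce op expression dispatched below performs.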
Maybe<void> Apply(const EagerNcclBroadcastCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& grad_op = JUST(FindOrCreatEagerNcclReduceOpExpr(ctx->parallel_desc, ctx->root));
in_grads->resize(1);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op, {out_grads.at(0)}));
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("eager_nccl_broadcast", EagerNcclBroadcast);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ElementwiseXimumCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool y_requires_grad = false;
};
class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumCaptureState> {
public:
Maybe<void> Capture(ElementwiseXimumCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->y_requires_grad = inputs.at(1)->requires_grad();
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(inputs.at(1));
return Maybe<void>::Ok();
}
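  // For z = min(x, y) or z = max(x, y) computed elementwise, the upstream gradient flows to
  // whichever operand attains the extremum at each position (tie handling is left to the grad
  // kernel); the shared Apply below just delegates to the functor installed by the subclass.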
Maybe<void> Apply(const ElementwiseXimumCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
    if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }
    in_grads->resize(2);
    const std::shared_ptr<one::Tensor>& x = ctx->SavedTensors().at(0);
    const std::shared_ptr<one::Tensor>& y = ctx->SavedTensors().at(1);
    const auto& grads = JUST(grad_functor(out_grads.at(0), x, y));
    if (ctx->x_requires_grad) { in_grads->at(0) = grads->at(0); }
    if (ctx->y_requires_grad) { in_grads->at(1) = grads->at(1); }
return Maybe<void>::Ok();
}
protected:
std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
const std::shared_ptr<Tensor>&)>
grad_functor;
};
class ElementwiseMinimum : public ElementwiseXimumOp {
public:
Maybe<void> Init(const OpExpr& op) override {
grad_functor = functional::ElementwiseMinGrad;
return Maybe<void>::Ok();
}
};
class ElementwiseMaximum : public ElementwiseXimumOp {
public:
Maybe<void> Init(const OpExpr& op) override {
grad_functor = functional::ElementwiseMaxGrad;
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_minimum", ElementwiseMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_maximum", ElementwiseMaximum);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct EmbeddingCaptureState : public AutoGradCaptureState {
int64_t padding_idx = -1;
bool scale_grad_by_freq = false;
bool requires_grad = false;
};
class Embedding : public OpExprGradFunction<EmbeddingCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Embedding::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Forward op must be not null";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Embedding::Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->padding_idx = JUST(composed_attrs.GetAttr<int64_t>("padding_idx"));
ctx->scale_grad_by_freq = JUST(composed_attrs.GetAttr<bool>("scale_grad_by_freq"));
return Maybe<void>::Ok();
}
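// Only the first input (the weight table) can carry a gradient here; the integer index tensor
// gets none. EmbeddingGrad produces a weight-shaped gradient by accumulating rows of out_grad
// at the captured indices, honoring the captured padding_idx and scale_grad_by_freq attributes.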
Maybe<void> Embedding::Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
in_grads->resize(ctx->SavedTensors().size());
const auto& weight = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
const auto& indices = JUST(oneflow::VectorAt(ctx->SavedTensors(), 1));
int64_t padding_idx = ctx->padding_idx;
bool scale_grad_by_freq = ctx->scale_grad_by_freq;
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::EmbeddingGrad(
JUST(oneflow::VectorAt(out_grads, 0)), weight, indices, padding_idx, scale_grad_by_freq));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("embedding", Embedding);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ExpandCaptureState : public AutoGradCaptureState {
std::vector<int32_t> logical_out_shape;
std::vector<int32_t> logical_expand_shape;
bool requires_grad;
};
class Expand : public OpExprGradFunction<ExpandCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Expand::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Expand::Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
  // The forward op's "logical_in_shape" is the shape the gradient must be reduced back to,
  // i.e. the logical output shape of the backward op.
  ctx->logical_out_shape = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("logical_in_shape"));
ctx->logical_expand_shape =
JUST(composed_attrs.GetAttr<std::vector<int32_t>>("logical_expand_shape"));
return Maybe<void>::Ok();
}
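// Backward of expand: the upstream gradient is summed over the expanded (broadcast) dimensions
// so that it shrinks back to the forward input's logical shape.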
Maybe<void> Expand::Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->at(0) = JUST(
functional::ExpandGrad(out_grads.at(0), ctx->logical_out_shape, ctx->logical_expand_shape));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("expand", Expand);
} // namespace one
} // namespace oneflow