Commit 21d47d0e authored by yuguo

Oneflow 0.8 for DCU
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct FakeQuantizationCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class FakeQuantization : public OpExprGradFunction<FakeQuantizationCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(FakeQuantizationCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 3);
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FakeQuantizationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(3);
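// Fake quantization backward acts as a straight-through estimator: dy passes to the first input
// unchanged, while the other two inputs (the scale/zero-point tensors) receive no grad.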
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fake_quantization", FakeQuantization);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/just.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
namespace oneflow {
namespace one {
struct FillCaptureState : public AutoGradCaptureState {
bool in_requires_grad = false;
bool value_requires_grad = false;
};
class Fill : public OpExprGradFunction<FillCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Fill::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Fill::Capture(FillCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->in_requires_grad = inputs[0]->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Fill::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads.size() must be equal to 1.";
in_grads->resize(1);
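// fill_ overwrites every element of the input, so the input's grad is a zero tensor of dy's shape.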
if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); }
return Maybe<void>::Ok();
}
class FillTensor : public OpExprGradFunction<FillCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FillCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FillTensor::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FillTensor::Capture(FillCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->in_requires_grad = inputs[0]->requires_grad();
ctx->value_requires_grad = inputs[1]->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> FillTensor::Apply(const FillCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads.size() must be equal to 1.";
in_grads->resize(2);
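// The fill value is broadcast to every element, so its grad is dy summed over all axes;
// the overwritten input's grad (below) is a zero tensor, as for fill_.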
if (ctx->value_requires_grad) {
int32_t num_axes = out_grads[0]->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
(*in_grads)[1] = JUST(functional::ReduceSum(out_grads[0], axes_vec, /*keepdims=*/false));
}
if (ctx->in_requires_grad) { (*in_grads)[0] = JUST(functional::Fill(out_grads[0], 0)); }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fill_", Fill);
REGISTER_OP_EXPR_GRAD_FUNCTION("fill_tensor_", FillTensor);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FlattenCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class Flatten : public OpExprGradFunction<FlattenCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
Maybe<void> Flatten::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
return Maybe<void>::Ok();
}
Maybe<void> Flatten::Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Flatten::Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
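// Flatten is a pure reshape, so dx is dy reshaped back to the saved input's shape.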
const auto& like = ctx->SavedTensors().at(0);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("flatten", Flatten);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FlipCaptureState : public AutoGradCaptureState {
bool requires_grad;
std::vector<int32_t> dims;
};
class Flip : public OpExprGradFunction<FlipCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Flip::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Flip::Capture(FlipCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dims"));
return Maybe<void>::Ok();
}
Maybe<void> Flip::Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
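// Flip is its own inverse, so dx is dy flipped along the same dims.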
if (ctx->requires_grad) { (*in_grads)[0] = JUST(functional::Flip(out_grads[0], ctx->dims)); }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("flip", Flip);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FoldInterpState : public AutoGradCaptureState {
bool requires_grad = true;
std::string data_format = "channels_first";
std::vector<int32_t> kernel_size;
std::vector<int32_t> dilation_rate;
std::vector<int32_t> padding;
std::vector<int32_t> strides;
};
class Fold : public OpExprGradFunction<FoldInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FoldInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const FoldInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Fold::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Fold::Capture(FoldInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
return Maybe<void>::Ok();
}
Maybe<void> Fold::Apply(const FoldInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
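// Fold (col2im) and Unfold (im2col) are adjoint operations, so the grad of fold is unfold
// with the same kernel_size, dilation, padding, and strides.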
in_grads->at(0) = JUST(functional::Unfold(out_grads.at(0), ctx->data_format, ctx->kernel_size,
ctx->dilation_rate, ctx->padding, ctx->strides));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fold", Fold);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedBiasAddDropoutInterpState : public AutoGradCaptureState {
bool input_requires_grad = true;
bool bias_requires_grad = true;
int32_t axis = 1;
float scale = 1.0;
};
class FusedBiasAddDropout : public OpExprGradFunction<FusedBiasAddDropoutInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedBiasAddDropoutInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedBiasAddDropoutInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedBiasAddDropout::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedBiasAddDropout::Capture(FusedBiasAddDropoutInterpState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3);
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
ctx->bias_requires_grad = inputs.at(1)->requires_grad(); // bias
if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->scale = JUST(composed_attrs.GetAttr<float>("scale"));
ctx->axis = JUST(composed_attrs.GetAttr<int32_t>("axis"));
ctx->SaveTensorForBackward(inputs.at(2));
return Maybe<void>::Ok();
}
Maybe<void> FusedBiasAddDropout::Apply(const FusedBiasAddDropoutInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }
// The mask has no grad (requires_grad=False), but it still occupies a slot in in_grads
in_grads->resize(3);
const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& dropout_grad =
JUST(functional::DropoutGrad(out_grads.at(0), mask, ctx->scale));
if (ctx->input_requires_grad) { in_grads->at(0) = dropout_grad; }
const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
if (ctx->bias_requires_grad) {
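// The bias is broadcast along every axis except `axis`, so its grad reduces dropout_grad
// over all the other axes.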
std::vector<int32_t> reduce_axes_vec;
reduce_axes_vec.reserve(num_axes);
for (int i = 0; i < num_axes; ++i) {
if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }
}
in_grads->at(1) = JUST(functional::ReduceSum(dropout_grad, reduce_axes_vec, false));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_mask_scale", FusedBiasAddDropout);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedBiasAddGeluInterpState : public AutoGradCaptureState {
bool input_requires_grad = true;
bool bias_requires_grad = true;
int32_t axis = 1;
};
class FusedBiasAddGelu : public OpExprGradFunction<FusedBiasAddGeluInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedBiasAddGeluInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->bias_requires_grad = inputs.at(1)->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int32_t>("axis"));
if (ctx->input_requires_grad || ctx->bias_requires_grad) {
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(inputs.at(1));
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedBiasAddGeluInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->input_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
in_grads->resize(2);
const auto& a = ctx->SavedTensors().at(0);
const auto& b = ctx->SavedTensors().at(1);
const std::shared_ptr<oneflow::one::Tensor>& fused_bias_add_gelu_grad =
JUST(functional::FusedBiasAddGeluGrad(a, b, out_grads.at(0), ctx->axis));
if (ctx->bias_requires_grad) {
std::vector<int32_t> reduce_axes_vec;
reduce_axes_vec.reserve(num_axes);
for (int i = 0; i < num_axes; ++i) {
if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }
}
in_grads->at(1) =
JUST(functional::ReduceSum(fused_bias_add_gelu_grad, reduce_axes_vec, false));
}
if (ctx->input_requires_grad) { in_grads->at(0) = fused_bias_add_gelu_grad; }
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_gelu", FusedBiasAddGelu);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct FusedCrossFeatureInteractionInterpState : public AutoGradCaptureState {
bool x_requires_grad = true;
bool weight_requires_grad = true;
bool x0_requires_grad = true;
bool bias_requires_grad = true;
size_t x_idx = 0;
size_t bias_idx = 0;
size_t weight_idx = 0;
size_t x0_idx = 0;
size_t matmul_result_idx = 0;
std::string interaction_mode;
};
class FusedCrossFeatureInteraction
: public OpExprGradFunction<FusedCrossFeatureInteractionInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be None. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedCrossFeatureInteractionInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 4) << "Input size should be equal to 4. ";
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->interaction_mode = JUST(composed_attrs.GetAttr<std::string>("interaction_mode"));
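// Inputs are ordered (x, weight, x0, bias); record each one's requires_grad flag.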
ctx->x_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
ctx->weight_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad();
ctx->x0_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad();
ctx->bias_requires_grad = JUST(oneflow::VectorAt(inputs, 3))->requires_grad();
ctx->x_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
ctx->weight_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));
ctx->x0_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2)));
if (ctx->interaction_mode == "matrix") {
ctx->bias_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 3)));
}
ctx->matmul_result_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedCrossFeatureInteractionInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 2) << "Out grads size should be equal to 2. ";
std::shared_ptr<oneflow::one::TensorTuple> grads;
in_grads->resize(4);
if (ctx->interaction_mode == "vector") {
grads = JUST(functional::FusedCrossFeatureInteractionV1Grad(
JUST(oneflow::VectorAt(out_grads, 0)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx))));
} else if (ctx->interaction_mode == "matrix") {
grads = JUST(functional::FusedCrossFeatureInteractionV2Grad(
JUST(oneflow::VectorAt(out_grads, 0)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->weight_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->bias_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->x0_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->matmul_result_idx))));
} else {
UNIMPLEMENTED_THEN_RETURN() << "Interaction mode only supports `vector` and `matrix`. ";
}
if (ctx->x_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0));
}
if (ctx->weight_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1));
}
if (ctx->x0_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2));
}
if (ctx->bias_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 3)) = JUST(oneflow::VectorAt(*grads, 3));
}
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_cross_feature_interaction", FusedCrossFeatureInteraction);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
struct FusedDotFeatureInteractionCaptureState : public AutoGradCaptureState {
bool need_grad_op = false;
std::vector<bool> features_requires_grad;
std::vector<int32_t> feature_dims;
int32_t output_concat_grad_dim = 0;
bool self_interaction = false;
bool has_output_concat = false;
bool has_output_concat_grad = false;
std::string pooling;
};
class FusedDotFeatureInteraction
: public OpExprGradFunction<FusedDotFeatureInteractionCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedDotFeatureInteractionCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedDotFeatureInteraction::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
return Maybe<void>::Ok();
}
Maybe<void> FusedDotFeatureInteraction::Capture(FusedDotFeatureInteractionCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->has_output_concat = JUST(attrs.GetAttr<bool>("has_output_concat"));
int32_t num_features = 0;
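// When has_output_concat is set, the last input is the output_concat tensor and all preceding
// inputs are feature tensors.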
if (ctx->has_output_concat) {
num_features = inputs.size() - 1;
const auto& output_concat = JUST(oneflow::VectorAt(inputs, num_features));
ctx->has_output_concat_grad = output_concat->requires_grad();
ctx->output_concat_grad_dim = output_concat->shape()->At(1);
} else {
num_features = inputs.size();
}
if (ctx->has_output_concat_grad) { ctx->need_grad_op = true; }
ctx->features_requires_grad.resize(num_features);
ctx->feature_dims.resize(num_features);
for (int32_t i = 0; i < num_features; ++i) {
const auto& feature = JUST(oneflow::VectorAt(inputs, i));
ctx->features_requires_grad[i] = feature->requires_grad();
ctx->feature_dims[i] = feature->shape()->At(1);
if (feature->requires_grad()) { ctx->need_grad_op = true; }
ctx->SaveTensorForBackward(feature);
}
ctx->pooling = JUST(attrs.GetAttr<std::string>("pooling"));
if (!ctx->need_grad_op) { return Maybe<void>::Ok(); }
ctx->self_interaction = JUST(attrs.GetAttr<bool>("self_interaction"));
return Maybe<void>::Ok();
}
Maybe<void> FusedDotFeatureInteraction::Apply(const FusedDotFeatureInteractionCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->need_grad_op) { return Maybe<void>::Ok(); }
int32_t num_features = ctx->features_requires_grad.size();
in_grads->resize(num_features + 1);
TensorTuple features(num_features);
for (int i = 0; i < num_features; ++i) {
features[i] = JUST(oneflow::VectorAt(ctx->SavedTensors(), i));
}
std::shared_ptr<oneflow::one::TensorTuple> grads;
grads = JUST(functional::FusedDotFeatureInteractionGrad(
JUST(oneflow::VectorAt(out_grads, 0)), features, ctx->has_output_concat,
ctx->self_interaction, ctx->output_concat_grad_dim, ctx->pooling));
for (int32_t i = 0; i < num_features; ++i) {
if (JUST(oneflow::VectorAt(ctx->features_requires_grad, i))) {
JUST(oneflow::VectorAt(*in_grads, i)) = JUST(oneflow::VectorAt(*grads, i));
}
}
if (ctx->has_output_concat_grad) {
JUST(oneflow::VectorAt(*in_grads, num_features)) =
JUST(oneflow::VectorAt(*grads, num_features));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_dot_feature_interaction", FusedDotFeatureInteraction);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedGruCellGradCaptureState : public AutoGradCaptureState {
bool has_bias = true;
bool hx_needs_grad = true;
};
class FusedGruCellGrad : public OpExprGradFunction<FusedGruCellGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "FusedGruCellGrad::Init forward op expr is null.";
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedGruCellGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
const size_t in_size = inputs.size();
CHECK_OR_RETURN(in_size == 3 || in_size == 5)
<< "FusedGruCellGrad::Capture(): input tensor size must be 3 or 5";
ctx->has_bias = in_size == 5;
ctx->hx_needs_grad = inputs[2]->requires_grad();
ctx->SaveTensorForBackward(outputs[1]); // workspace
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedGruCellGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& workspace = ctx->SavedTensors()[0]; // workspace
const auto& grad_hy = out_grads[0];
const auto& results =
JUST(functional::FusedGruCellGrad(grad_hy, workspace, ctx->has_bias, ctx->hx_needs_grad));
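// `results` packs the input grads in order; when hx needs no grad, the bias grads shift down
// by one slot.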
if (ctx->has_bias) {
in_grads->resize(5);
} else {
in_grads->resize(3);
}
(*in_grads)[0] = (*results)[0];
(*in_grads)[1] = (*results)[1];
if (ctx->hx_needs_grad) { (*in_grads)[2] = (*results)[2]; }
if (ctx->has_bias) {
if (ctx->hx_needs_grad) {
(*in_grads)[3] = (*results)[3];
(*in_grads)[4] = (*results)[4];
} else {
(*in_grads)[3] = (*results)[2];
(*in_grads)[4] = (*results)[3];
}
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_gru_cell", FusedGruCellGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedLstmCellGradCaptureState : public AutoGradCaptureState {
bool has_bias = true;
bool need_grad_cx = true;
};
class FusedLstmCellGrad : public OpExprGradFunction<FusedLstmCellGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "FusedLstmCellGrad::Init forward op expr is null.";
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedLstmCellGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
const size_t in_size = inputs.size();
CHECK_OR_RETURN(in_size == 3 || in_size == 5)
<< "FusedLstmCellGrad::Capture(): input tensor size must be 3 or 5";
ctx->has_bias = in_size == 5;
ctx->need_grad_cx = inputs[2]->requires_grad();
ctx->SaveTensorForBackward(inputs[2]); // cx
ctx->SaveTensorForBackward(outputs[1]); // cy
ctx->SaveTensorForBackward(outputs[2]); // workspace
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedLstmCellGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& cx = ctx->SavedTensors()[0]; // cx
const auto& cy = ctx->SavedTensors()[1]; // cy
const auto& workspace = ctx->SavedTensors()[2]; // workspace
const auto& grad_hy = out_grads[0];
const auto& grad_cy = out_grads[1];
const auto& results = JUST(functional::FusedLstmCellGrad(grad_hy, grad_cy, cx, cy, workspace,
ctx->need_grad_cx, ctx->has_bias));
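// results[0] is shared by the first two inputs (both gate tensors); the bias grad is likewise
// shared by both bias inputs, taken from results[2] when cx also needs grad and from results[1]
// otherwise.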
if (ctx->has_bias) {
in_grads->resize(5);
} else {
in_grads->resize(3);
}
(*in_grads)[0] = (*results)[0];
(*in_grads)[1] = (*results)[0];
if (ctx->need_grad_cx) { (*in_grads)[2] = (*results)[1]; }
if (ctx->has_bias) {
if (ctx->need_grad_cx) {
(*in_grads)[3] = (*results)[2];
(*in_grads)[4] = (*results)[2];
} else {
(*in_grads)[3] = (*results)[1];
(*in_grads)[4] = (*results)[1];
}
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_lstm_cell", FusedLstmCellGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#if CUDA_VERSION >= 11060
namespace oneflow {
namespace one {
struct FusedMatmulBiasAddReluDropoutCaptureState : public AutoGradCaptureState {
int32_t weight_num = 0;
bool skip_final_activation = false;
bool x_requires_grad = false;
std::vector<bool> weights_requires_grad;
std::vector<bool> biases_requires_grad;
std::vector<float> dropout_rate_list;
};
class FusedMatmulBiasAddReluDropout
: public OpExprGradFunction<FusedMatmulBiasAddReluDropoutCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedMatmulBiasAddReluDropoutCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
protected:
AttrMap base_attrs_;
};
Maybe<void> FusedMatmulBiasAddReluDropout::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedMatmulBiasAddReluDropout::Capture(FusedMatmulBiasAddReluDropoutCaptureState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_OR_RETURN(inputs.size() % 2 == 1) << "Both weight and bias should be passed together. ";
int32_t weight_num = (inputs.size() - 1) / 2;
ctx->weight_num = weight_num;
ctx->x_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
ctx->weights_requires_grad.resize(weight_num);
ctx->biases_requires_grad.resize(weight_num);
for (int32_t i = 0; i < weight_num; i++) {
ctx->weights_requires_grad.at(i) = inputs.at(i + 1)->requires_grad(); // NOLINT
ctx->biases_requires_grad.at(i) = inputs.at(i + 1 + weight_num)->requires_grad(); // NOLINT
}
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // x. idx_sum:1
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, i + 1))); // weights. idx_sum:1+w
}
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0))); // final layers output. idx_sum:2+w
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(
JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w
}
for (int32_t i = 0; i < weight_num - 1; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden.
}
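// Saved-tensor layout (e.g. weight_num = 3): 0 -> x, 1..3 -> weights, 4 -> final output,
// 5..7 -> cublas aux, 8..9 -> hidden activations.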
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->skip_final_activation = JUST(composed_attrs.GetAttr<bool>("skip_final_activation"));
ctx->dropout_rate_list = JUST(composed_attrs.GetAttr<std::vector<float>>("dropout_rate_list"));
return Maybe<void>::Ok();
}
Maybe<void> FusedMatmulBiasAddReluDropout::Apply(
const FusedMatmulBiasAddReluDropoutCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
int32_t weight_num = ctx->weight_num;
in_grads->resize(1 + 2 * weight_num);
TensorTuple hiddens(weight_num - 1);
TensorTuple weights(weight_num);
TensorTuple cublas_auxs(weight_num);
TensorTuple dgrad(weight_num);
std::shared_ptr<one::Tensor> x = JUST(VectorAt(ctx->SavedTensors(), 0));
std::shared_ptr<one::Tensor> out = JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num));
for (int32_t i = 0; i < weight_num; ++i) {
weights[i] = JUST(VectorAt(ctx->SavedTensors(), 1 + i));
}
for (int32_t i = 0; i < weight_num; ++i) {
cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));
}
for (int32_t i = 0; i < weight_num - 1; ++i) {
hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));
}
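// Inverted dropout: surviving activations were scaled by 1/(1-rate) in the forward pass,
// so the backward reuses that scale (left at 0 when rate >= 1).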
float rate = ctx->dropout_rate_list.at(weight_num - 1);
float scale = 0.0f;
if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }
/*
Step 1: use dy and the mask to compute the last layer's dropout + relu grad.
Because the curand_uniform distribution is (0.0, 1.0], the post-relu values are also written into
the mask, so DropoutGrad driven by this mask produces the dropout and relu grads simultaneously.
*/
std::shared_ptr<one::Tensor> last_bias_dy = JUST(VectorAt(out_grads, 0));
if (!ctx->skip_final_activation || rate != 0.0f) {
last_bias_dy = JUST(functional::FusedReluDropoutGrad(JUST(VectorAt(out_grads, 0)),
cublas_auxs[weight_num - 1], scale));
}
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}
std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {
// For the final layer cublas_dy is already last_bias_dy (derived from out_grads[0]); earlier
// layers reuse the dgrad of the layer above.
if (hidden_layer_idx != weight_num - 1) {
cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));
}
rate = ctx->dropout_rate_list.at(hidden_layer_idx - 1);
scale = 1.0;
if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }
/*
Here we use cublas to compute bias + relu + matmul grad.
Then use Matmul to compute weight grad.
*/
const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad(
cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),
JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/scale));
// dgrad
dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT
if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =
matmul_relu_bias_bgrad->at(1); // NOLINT
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(
cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));
}
}
// For the first layer, we need to use 2 matmul to get grads.
std::shared_ptr<one::Tensor> last_dy;
if (weight_num != 1) {
last_dy = JUST(VectorAt(dgrad, 1));
} else {
last_dy = last_bias_dy;
}
if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_matmul_bias_add_relu_dropout", FusedMatmulBiasAddReluDropout);
} // namespace one
} // namespace oneflow
#endif // CUDA_VERSION >= 11060
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedScaleMaskSoftmaxInterState : public AutoGradCaptureState {
bool input_requires_grad = false;
float scale = 1.0;
};
class FusedScaleMaskSoftmax : public OpExprGradFunction<FusedScaleMaskSoftmaxInterState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedScaleMaskSoftmaxInterState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedScaleMaskSoftmax::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleMaskSoftmax::Capture(FusedScaleMaskSoftmaxInterState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // input, mask
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->scale = JUST(composed_attrs.GetAttr<float>("scale_value"));
ctx->SaveTensorForBackward(inputs.at(1)); // save mask
ctx->SaveTensorForBackward(outputs.at(0)); // save y, i.e. the softmax result
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleMaskSoftmax::Apply(const FusedScaleMaskSoftmaxInterState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // dy
in_grads->resize(2); // input, mask
const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(1);
const std::shared_ptr<oneflow::one::Tensor>& fused_scale_mask_softmax_grad =
JUST(functional::FusedScaleMaskSoftmaxGrad(y, out_grads.at(0), mask, ctx->scale));
in_grads->at(0) = fused_scale_mask_softmax_grad;
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_mask_softmax", FusedScaleMaskSoftmax);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedScaleMaskSoftmaxDropoutInterState : public AutoGradCaptureState {
bool input_requires_grad = true;
float scale = 1.0;
float dropout_scale = 1.0;
};
class FusedScaleMaskSoftmaxDropout
: public OpExprGradFunction<FusedScaleMaskSoftmaxDropoutInterState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedScaleMaskSoftmaxDropout::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleMaskSoftmaxDropout::Capture(FusedScaleMaskSoftmaxDropoutInterState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3); // input, mask, dropout_mask
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->scale = JUST(composed_attrs.GetAttr<float>("scale_value"));
ctx->dropout_scale = JUST(composed_attrs.GetAttr<float>("dropout_scale_value"));
ctx->SaveTensorForBackward(inputs.at(1)); // mask
ctx->SaveTensorForBackward(inputs.at(2)); // dropout_mask
ctx->SaveTensorForBackward(outputs.at(1)); // softmax_y
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleMaskSoftmaxDropout::Apply(const FusedScaleMaskSoftmaxDropoutInterState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // dy, d_softmax_y
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
in_grads->resize(3); // input, mask, dropout_mask
const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& dropout_mask = ctx->SavedTensors().at(1);
const std::shared_ptr<oneflow::one::Tensor>& softmax_y = ctx->SavedTensors().at(2);
const std::shared_ptr<oneflow::one::Tensor>& input_grad =
JUST(functional::FusedScaleMaskSoftmaxDropoutGrad(
softmax_y, out_grads.at(0), mask, dropout_mask, ctx->scale, ctx->dropout_scale));
in_grads->at(0) = input_grad;
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_mask_softmax_dropout", FusedScaleMaskSoftmaxDropout);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedScaleTrilState : public AutoGradCaptureState {
bool requires_grad;
int64_t diagonal;
double floating_scale_value;
int64_t integer_scale_value;
bool is_floating_scale_value;
};
class FusedScaleTril : public OpExprGradFunction<FusedScaleTrilState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedScaleTril::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleTril::Capture(FusedScaleTrilState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>("diagonal"));
ctx->floating_scale_value = JUST(composed_attrs.GetAttr<double>("floating_scale_value"));
ctx->integer_scale_value = JUST(composed_attrs.GetAttr<int64_t>("integer_scale_value"));
ctx->is_floating_scale_value = JUST(composed_attrs.GetAttr<bool>("is_floating_scale_value"));
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleTril::Apply(const FusedScaleTrilState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
Scalar scale;
if (ctx->is_floating_scale_value) {
scale = ctx->floating_scale_value;
} else {
scale = ctx->integer_scale_value;
}
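// dx presumably re-applies the same scaled tril to dy, zeroing elements outside the kept triangle.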
(*in_grads)[0] = JUST(functional::FusedScaleTril(out_grads[0], ctx->diagonal, 0, scale));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_scale_tril", FusedScaleTril);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedScaleTrilSoftmaxMaskScaleInterpState : public AutoGradCaptureState {
bool input_requires_grad = true;
int64_t diagonal = 0;
float tril_scale_value = 0.0;
float mask_scale_value = 1.0;
};
class FusedScaleTrilSoftmaxMaskScale
: public OpExprGradFunction<FusedScaleTrilSoftmaxMaskScaleInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> FusedScaleTrilSoftmaxMaskScale::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleTrilSoftmaxMaskScale::Capture(FusedScaleTrilSoftmaxMaskScaleInterpState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>("diagonal"));
ctx->tril_scale_value = JUST(composed_attrs.GetAttr<float>("tril_scale_value"));
ctx->mask_scale_value = JUST(composed_attrs.GetAttr<float>("mask_scale_value"));
ctx->SaveTensorForBackward(inputs.at(1)); // Save Mask
ctx->SaveTensorForBackward(outputs.at(1)); // Save softmax_y
return Maybe<void>::Ok();
}
Maybe<void> FusedScaleTrilSoftmaxMaskScale::Apply(
const FusedScaleTrilSoftmaxMaskScaleInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // Because the output has both y and softmax_y
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
// The mask has no grad (requires_grad=False), but it still occupies a slot in in_grads
in_grads->resize(2);
const std::shared_ptr<oneflow::one::Tensor>& mask = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& softmax_y = ctx->SavedTensors().at(1);
const std::shared_ptr<oneflow::one::Tensor>& input_grad =
JUST(functional::FusedScaleTrilSoftmaxMaskScaleGrad(softmax_y, out_grads.at(0), mask,
ctx->diagonal, ctx->tril_scale_value,
ctx->mask_scale_value));
if (ctx->input_requires_grad) { in_grads->at(0) = input_grad; }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_tril_scale_softmax_mask_scale",
FusedScaleTrilSoftmaxMaskScale);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct FusedSelfAttentionInterpState : public AutoGradCaptureState {
bool input_requires_grad = false;
float alpha = 1.0;
};
class FusedSelfAttention : public OpExprGradFunction<FusedSelfAttentionInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(FusedSelfAttentionInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const FusedSelfAttentionInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2);
in_grads->resize(1);
const auto& hidden_states = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& fused_self_attention_grad =
JUST(functional::FusedSelfAttentionGrad(out_grads.at(0), out_grads.at(1), hidden_states,
ctx->alpha));
in_grads->at(0) = fused_self_attention_grad;
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_self_attention_query_mul_key_and_value", FusedSelfAttention);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct GatherCaptureState : public AutoGradCaptureState {
  int64_t axis = 0;
  bool requires_grad = false;
};
class Gather : public OpExprGradFunction<GatherCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(GatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Gather::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
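// Capture saves both the input tensor and the indices: the input only provides the shape
// of the gradient to materialize, while the indices give the scatter positions.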
Maybe<void> Gather::Capture(GatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(inputs.at(1));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
return Maybe<void>::Ok();
}
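// Gather's backward scatters the upstream gradient back to where it was read from:
// UnsortedSegmentSumLike accumulates out_grad at `indices` along `axis` into a tensor
// shaped like the saved input x, summing contributions of repeated indices.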
Maybe<void> Gather::Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& x = ctx->SavedTensors().at(0);
const auto& indices = ctx->SavedTensors().at(1);
  in_grads->resize(2);
  in_grads->at(0) =
      JUST(functional::UnsortedSegmentSumLike(out_grads.at(0), indices, x, ctx->axis));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("gather", Gather);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct GatherNdCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};
class GatherNd : public OpExprGradFunction<GatherNdCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(GatherNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (ctx->requires_grad) {
ctx->SaveTensorForBackward(inputs.at(0)); // params
ctx->SaveTensorForBackward(inputs.at(1)); // indices
}
return Maybe<void>::Ok();
}
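  // The params gradient is the inverse of the forward lookup: ScatterNdLike writes
  // out_grad back at `indices` into a zero-initialized tensor shaped like the saved
  // params; the indices input itself gets no gradient (in_grads[1] stays empty).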
Maybe<void> Apply(const GatherNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& params = ctx->SavedTensors().at(0);
const auto& indices = ctx->SavedTensors().at(1);
in_grads->at(0) = JUST(functional::ScatterNdLike(params, out_grads.at(0), indices));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("gather_nd", GatherNd);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct GridSampleInterpState : public AutoGradCaptureState {
std::string interpolation_mode = "";
std::string padding_mode = "";
bool align_corners = false;
size_t input_index = -1;
size_t grid_index = -1;
bool input_requires_grad = false;
bool grid_requires_grad = false;
bool requires_grad = false;
};
class GridSample : public OpExprGradFunction<GridSampleInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Capture(GridSampleInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->grid_requires_grad = inputs.at(1)->requires_grad();
ctx->requires_grad = ctx->input_requires_grad || ctx->grid_requires_grad;
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->grid_index = ctx->SaveTensorForBackward(inputs.at(1)); // grid
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->interpolation_mode = JUST(composed_attrs.GetAttr<std::string>("interpolation_mode"));
ctx->padding_mode = JUST(composed_attrs.GetAttr<std::string>("padding_mode"));
ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
return Maybe<void>::Ok();
}
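  // GridSampleGrad returns a pair of gradients, one for the sampled input and one for
  // the sampling grid; each is forwarded only when the corresponding tensor requires grad.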
Maybe<void> Apply(const GridSampleInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& input = ctx->SavedTensors().at(ctx->input_index);
const auto& grid = ctx->SavedTensors().at(ctx->grid_index);
const auto& results =
JUST(functional::GridSampleGrad(out_grads.at(0), input, grid, ctx->interpolation_mode,
ctx->padding_mode, ctx->align_corners));
in_grads->resize(2);
if (ctx->input_requires_grad) { in_grads->at(0) = results->at(0); }
if (ctx->grid_requires_grad) { in_grads->at(1) = results->at(1); }
return Maybe<void>::Ok();
}
private:
AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("grid_sample", GridSample);
} // namespace one
} // namespace oneflow