Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
......@@ -81,7 +81,7 @@ Maybe<void> CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const Tenso
ctx->SaveTensorForBackward(
JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w
}
for (int32_t i = 0; i < weight_num - 1; i++) {
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden.
}
......@@ -103,14 +103,7 @@ Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num))));
}
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}
TensorTuple hiddens(weight_num - 1);
TensorTuple hiddens(weight_num);
TensorTuple weights(weight_num);
TensorTuple cublas_auxs(weight_num);
TensorTuple dgrad(weight_num);
......@@ -125,11 +118,44 @@ Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));
}
for (int32_t i = 0; i < weight_num - 1; ++i) {
for (int32_t i = 0; i < weight_num; ++i) {
hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));
}
std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;
// Use Fully Fused MLP Backward.
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) {
const std::vector<float> alpha_list(weight_num - 1, 1.0);
const auto& fused_mlp_grad =
JUST(functional::FusedMLPGrad(cublas_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights,
cublas_auxs, hiddens, alpha_list));
if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0);
}
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) {
if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) =
fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) =
fused_mlp_grad->at(1 + weight_num + hidden_layer_idx);
}
}
} else {
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {
// If it is final layer, we use out_grads[0] as dy.
if (hidden_layer_idx != weight_num - 1) {
......@@ -173,8 +199,9 @@ Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
JUST(VectorAt(*in_grads, 1)) = JUST(
functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
}
}
return Maybe<void>::Ok();
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Capture state for the deform_conv2d gradient: records which inputs need
// grads plus the convolution hyper-parameters copied from the forward op's
// attrs so the backward ops are built with identical settings.
struct DeformConvNdCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;
  bool offset_requires_grad = false;
  bool weight_requires_grad = false;
  bool mask_requires_grad = false;
  // NOTE(review): bias_requires_grad is never set or read in this file —
  // confirm whether this op actually takes a bias input.
  bool bias_requires_grad = false;
  int32_t stride_h = 0;
  int32_t stride_w = 0;
  int32_t pad_h = 0;
  int32_t pad_w = 0;
  int32_t dilation_h = 0;
  int32_t dilation_w = 0;
  int32_t groups = 0;
  int32_t offset_groups = 0;
  bool use_mask = false;  // true when the forward op consumed the mask input
};
// OpExprGradFunction computing gradients of deform_conv2d w.r.t. input,
// weight, offset and (optionally) mask.
class DeformConvNd : public OpExprGradFunction<DeformConvNdCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  // Records requires-grad flags, saves the four forward inputs and copies the
  // convolution attrs into ctx.
  Maybe<void> Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  // Dispatches to DeformConv2dInputGrad / DeformConv2dParamGrad.
  Maybe<void> Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // forward-op attrs, used as fallback values in Capture
};
// Cache the forward op's attribute map so Capture() can fall back to it for
// attrs not present in the per-call AttrMap.
Maybe<void> DeformConvNd::Init(const OpExpr& op) {
  const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
  return Maybe<void>::Ok();
}
// Record requires-grad flags for (input, weight, offset, mask), save all four
// inputs (both backward kernels consume them), and copy the conv attrs.
Maybe<void> DeformConvNd::Capture(DeformConvNdCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->input_requires_grad = inputs.at(0)->requires_grad();
  ctx->weight_requires_grad = inputs.at(1)->requires_grad();
  ctx->offset_requires_grad = inputs.at(2)->requires_grad();
  ctx->mask_requires_grad = inputs.at(3)->requires_grad();
  // Save in input order: input, weight, offset, mask (indices 0..3 in Apply).
  for (int i = 0; i < 4; ++i) { ctx->SaveTensorForBackward(inputs.at(i)); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  const auto int_attr = [&](const char* name) { return composed_attrs.GetAttr<int32_t>(name); };
  ctx->use_mask = JUST(composed_attrs.GetAttr<bool>("use_mask"));
  ctx->stride_h = JUST(int_attr("stride_h"));
  ctx->stride_w = JUST(int_attr("stride_w"));
  ctx->pad_h = JUST(int_attr("pad_h"));
  ctx->pad_w = JUST(int_attr("pad_w"));
  ctx->dilation_h = JUST(int_attr("dilation_h"));
  ctx->dilation_w = JUST(int_attr("dilation_w"));
  ctx->groups = JUST(int_attr("groups"));
  ctx->offset_groups = JUST(int_attr("offset_groups"));
  return Maybe<void>::Ok();
}
// Backward of deform_conv2d.
// in_grads layout: (0) input, (1) weight, (2) offset, (3) mask, (4) bias.
Maybe<void> DeformConvNd::Apply(const DeformConvNdCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  // Slot 4 (bias) is allocated but never filled in this function.
  in_grads->resize(5);
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  // Saved in Capture() in this exact order.
  const auto& input = ctx->SavedTensors().at(0);
  const auto& weight = ctx->SavedTensors().at(1);
  const auto& offset = ctx->SavedTensors().at(2);
  const auto& mask = ctx->SavedTensors().at(3);
  const auto& output_grad = out_grads.at(0);
  if (ctx->input_requires_grad || ctx->offset_requires_grad || ctx->mask_requires_grad) {
    // One fused kernel yields (input_grad, offset_grad[, mask_grad]).
    std::shared_ptr<TensorTuple> grads_tuple;
    if (ctx->use_mask) {
      grads_tuple = JUST(functional::DeformConv2dInputGrad(
          output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h,
          ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,
          ctx->use_mask));
    } else {
      // No mask in the forward pass: pass NullOpt so the kernel skips it.
      grads_tuple = JUST(functional::DeformConv2dInputGrad(
          output_grad, input, weight, offset, NullOpt, ctx->stride_h, ctx->stride_w, ctx->pad_h,
          ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,
          ctx->use_mask));
    }
    if (ctx->input_requires_grad) {
      in_grads->at(0) = grads_tuple->at(0);  // input_grad
    }
    if (ctx->offset_requires_grad) {
      in_grads->at(2) = grads_tuple->at(1);  // offset_grad
    }
    if (ctx->use_mask && ctx->mask_requires_grad) {
      in_grads->at(3) = grads_tuple->at(2);  // mask_grad
    }
  }
  if (ctx->weight_requires_grad) {  // weight_grad
    // NOTE(review): `mask` is passed here even when use_mask == false —
    // presumably DeformConv2dParamGrad ignores it in that case; confirm.
    in_grads->at(1) = JUST(functional::DeformConv2dParamGrad(
        output_grad, input, weight, offset, mask, ctx->stride_h, ctx->stride_w, ctx->pad_h,
        ctx->pad_w, ctx->dilation_h, ctx->dilation_w, ctx->groups, ctx->offset_groups,
        ctx->use_mask));
  }
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("deform_conv2d", DeformConvNd);
} // namespace one
} // namespace oneflow
\ No newline at end of file
......@@ -26,10 +26,9 @@ struct DimScatterCaptureState : public AutoGradCaptureState {
bool input_requires_grad;
bool src_requires_grad;
};
enum class ScatterType { kUpdate, kAdd, kMultiply };
enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD };
template<SCATTER_TYPE T>
template<ScatterType T>
class DimScatter : public OpExprGradFunction<DimScatterCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
......@@ -37,14 +36,12 @@ class DimScatter : public OpExprGradFunction<DimScatterCaptureState> {
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
Maybe<void> ApplyCommon(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const;
private:
AttrMap base_attrs_;
};
template<SCATTER_TYPE T>
template<ScatterType T>
Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
......@@ -52,7 +49,7 @@ Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
template<ScatterType T>
Maybe<void> DimScatter<T>::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
......@@ -63,52 +60,43 @@ Maybe<void> DimScatter<T>::Capture(DimScatterCaptureState* ctx, const TensorTupl
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1)); // index saved
if (T == ScatterType::kMultiply) {
ctx->SaveTensorForBackward(inputs.at(2)); // src saved
}
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
template<ScatterType T>
Maybe<void> DimScatter<T>::Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3);
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
if (ctx->src_requires_grad) {
in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), ctx->dim, index, false));
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (ctx->input_requires_grad) {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f));
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (T == ScatterType::kAdd) { in_grads->at(0) = out_grads.at(0); }
if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }
if (T == ScatterType::kUpdate) {
in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index,
0.0f, /*inplace*/ false));
}
if (T == ScatterType::kMultiply) {
const std::shared_ptr<oneflow::one::Tensor>& src = ctx->SavedTensors().at(1);
in_grads->at(0) =
JUST(functional::DimScatterMul(out_grads.at(0), ctx->dim, index, src, /*inplace*/ false));
}
}
return Maybe<void>::Ok();
}
......@@ -156,18 +144,15 @@ Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterCaptureState* ctx,
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->resize(2);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", ctx->dim));
JUST(attrs.SetAttr<float>("src_scalar", 0.0f));
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f));
in_grads->at(0) = JUST(functional::DimScatterUpdateScalar(out_grads.at(0), ctx->dim, index, 0.0f,
/*inplace*/ false));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<SCATTER_TYPE::SCATTER_UPDATE>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<SCATTER_TYPE::SCATTER_ADD>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<ScatterType::kUpdate>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<ScatterType::kAdd>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_mul", DimScatter<ScatterType::kMultiply>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar);
} // namespace one
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
namespace {
// Builds an eager_ccl_reduce op expression that reduces onto rank `root`
// over the placement described by `parallel_desc`.
Maybe<one::UserOpExpr> EagerCclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {
  auto builder = one::OpBuilder("eager_ccl_reduce", *JUST(UniqueStr("eager_ccl_reduce")));
  return builder.Input("in")
      .Output("out")
      .Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
      .Attr<int64_t>("root", root)
      .Build();
}
// Returns the cached eager_ccl_reduce op for (parallel_desc, root), building
// it on first use. The cache is thread_local, so no locking is required.
Maybe<one::UserOpExpr> FindOrCreatEagerCclReduceOpExpr(Symbol<ParallelDesc> parallel_desc,
                                                       int64_t root) {
  thread_local HashMap<std::pair<Symbol<ParallelDesc>, int64_t>, std::shared_ptr<one::UserOpExpr>>
      reduce_op_cache;
  const std::pair<Symbol<ParallelDesc>, int64_t> key(parallel_desc, root);
  auto iter = reduce_op_cache.find(key);
  if (iter != reduce_op_cache.end()) { return iter->second; }
  std::shared_ptr<UserOpExpr> op_expr = JUST(EagerCclReduce(parallel_desc, root));
  return reduce_op_cache.emplace(key, op_expr).first->second;
}
} // namespace
// Captured state for eager_ccl_broadcast: the placement and broadcast root,
// which are all the backward pass needs to build the matching reduce op.
struct EagerCclBroadcastCaptureState : public AutoGradCaptureState {  // NOLINT
  Symbol<ParallelDesc> parallel_desc;
  int64_t root;  // left uninitialized until Capture() runs
};
// Gradient of eager_ccl_broadcast: the backward of a broadcast from `root`
// is a reduce of the output grads back onto `root`.
class EagerCclBroadcast : public OpExprGradFunction<EagerCclBroadcastCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    return Maybe<void>::Ok();
  }

  // Record the broadcast root and the forward op's placement; no tensors are
  // needed for backward.
  Maybe<void> Capture(EagerCclBroadcastCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs,
                      const OpExprInterpContext& interp_ctx) const override {
    ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>("root"));
    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
    return Maybe<void>::Ok();
  }

  // dx = reduce(dy) onto the broadcast root.
  Maybe<void> Apply(const EagerCclBroadcastCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& grad_op = JUST(FindOrCreatEagerCclReduceOpExpr(ctx->parallel_desc, ctx->root));
    in_grads->resize(1);
    (*in_grads)[0] = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op, {out_grads.at(0)}));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("eager_ccl_broadcast", EagerCclBroadcast);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
namespace {
// Builds an eager_nccl_reduce op expression that reduces onto rank `root`
// over the placement described by `parallel_desc`.
Maybe<one::UserOpExpr> EagerNcclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {
  auto builder = one::OpBuilder("eager_nccl_reduce", *JUST(UniqueStr("eager_nccl_reduce")));
  return builder.Input("in")
      .Output("out")
      .Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
      .Attr<int64_t>("root", root)
      .Build();
}
// Returns the cached eager_nccl_reduce op for (parallel_desc, root), building
// it on first use. The cache is thread_local, so no locking is required.
Maybe<one::UserOpExpr> FindOrCreatEagerNcclReduceOpExpr(Symbol<ParallelDesc> parallel_desc,
                                                        int64_t root) {
  thread_local HashMap<std::pair<Symbol<ParallelDesc>, int64_t>, std::shared_ptr<one::UserOpExpr>>
      reduce_op_cache;
  const std::pair<Symbol<ParallelDesc>, int64_t> key(parallel_desc, root);
  auto iter = reduce_op_cache.find(key);
  if (iter != reduce_op_cache.end()) { return iter->second; }
  std::shared_ptr<UserOpExpr> op_expr = JUST(EagerNcclReduce(parallel_desc, root));
  return reduce_op_cache.emplace(key, op_expr).first->second;
}
} // namespace
// Captured state for eager_nccl_broadcast: the placement and broadcast root,
// which are all the backward pass needs to build the matching reduce op.
struct EagerNcclBroadcastCaptureState : public AutoGradCaptureState {
  Symbol<ParallelDesc> parallel_desc;
  int64_t root;  // left uninitialized until Capture() runs
};
// Gradient of eager_nccl_broadcast: the backward of a broadcast from `root`
// is a reduce of the output grads back onto `root`.
class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    return Maybe<void>::Ok();
  }

  // Record the broadcast root and the forward op's placement; no tensors are
  // needed for backward.
  Maybe<void> Capture(EagerNcclBroadcastCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs,
                      const OpExprInterpContext& interp_ctx) const override {
    ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>("root"));
    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
    return Maybe<void>::Ok();
  }

  // dx = reduce(dy) onto the broadcast root.
  Maybe<void> Apply(const EagerNcclBroadcastCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& grad_op = JUST(FindOrCreatEagerNcclReduceOpExpr(ctx->parallel_desc, ctx->root));
    in_grads->resize(1);
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op, {out_grads.at(0)}));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("eager_nccl_broadcast", EagerNcclBroadcast);
} // namespace one
} // namespace oneflow
......@@ -21,9 +21,10 @@ namespace oneflow {
namespace one {
struct ExpandCaptureState : public AutoGradCaptureState {
std::vector<int32_t> logical_out_shape;
std::vector<int32_t> logical_expand_shape;
bool requires_grad;
int32_t lpad;
bool keep_dims;
std::vector<int32_t> reduce_dims;
};
class Expand : public OpExprGradFunction<ExpandCaptureState> {
......@@ -33,39 +34,51 @@ class Expand : public OpExprGradFunction<ExpandCaptureState> {
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Expand::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Expand::Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs[0]->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->logical_out_shape = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("logical_in_shape"));
ctx->logical_expand_shape =
JUST(composed_attrs.GetAttr<std::vector<int32_t>>("logical_expand_shape"));
const Shape& in_shape = *inputs[0]->shape();
const Shape& expand_shape = *outputs[0]->shape();
ctx->lpad = expand_shape.size() - in_shape.size();
ctx->keep_dims = (in_shape.size() > 0);
ctx->reduce_dims.reserve(expand_shape.size());
if (ctx->keep_dims) {
for (size_t i = 0; i < expand_shape.size(); ++i) {
const auto& t_dim = expand_shape[i];
const auto& dim = i < ctx->lpad ? 1 : in_shape[i - ctx->lpad];
if (dim != t_dim) { ctx->reduce_dims.push_back(i); }
}
} else {
for (int32_t axis = 0; axis < expand_shape.size(); ++axis) { ctx->reduce_dims.push_back(axis); }
}
return Maybe<void>::Ok();
}
Maybe<void> Expand::Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("logical_out_shape", ctx->logical_out_shape));
JUST(attrs.SetAttr<std::vector<int32_t>>("logical_expand_shape", ctx->logical_expand_shape));
in_grads->at(0) = JUST(
functional::ExpandGrad(out_grads.at(0), ctx->logical_out_shape, ctx->logical_expand_shape));
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
in_grads->at(0) = out_grads[0];
if (ctx->reduce_dims.size() > 0) {
in_grads->at(0) =
JUST(functional::ReduceSum(in_grads->at(0), ctx->reduce_dims, ctx->keep_dims));
}
if (ctx->lpad > 0 && ctx->keep_dims) {
in_grads->at(0) = JUST(functional::Flatten(in_grads->at(0), 0, ctx->lpad));
}
return Maybe<void>::Ok();
}
......
......@@ -66,8 +66,8 @@ Maybe<void> Fold::Apply(const FoldInterpState* ctx, const TensorTuple& out_grads
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::Unfold(out_grads.at(0), ctx->data_format, ctx->kernel_size,
ctx->dilation_rate, ctx->padding, ctx->strides));
in_grads->at(0) = JUST(functional::Unfold(out_grads.at(0), ctx->kernel_size, ctx->dilation_rate,
ctx->padding, ctx->strides, ctx->data_format));
return Maybe<void>::Ok();
}
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
// Capture state for fused_bias_add_scale_mask_softmax_dropout backward.
// The *_index fields are positions inside SavedTensors(); -1 means "not saved".
struct FusedBiasAddScaleMaskSoftmaxDropoutCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool bias_requires_grad = false;
  bool bias_broadcast = false;  // bias shape differs from x shape (needs reduce-like grad)
  int softmax_y_index = -1;
  int bias_index = -1;
  int mask_index = -1;
  int dropout_mask_index = -1;
  float scale = 1.0;          // forward "scale_value" attr
  float dropout_scale = 1.0;  // forward "dropout_scale_value" attr
};
// Gradient of fused_bias_add_scale_mask_softmax_dropout.
// Forward: (x, bias, mask, dropout_mask) -> (y, softmax_y).
// Backward: dx via FusedScaleMaskSoftmaxDropoutGrad; dbias is dx, reduced to
// bias's shape when bias was broadcast in the forward pass.
class FusedBiasAddScaleMaskSoftmaxDropoutGradFunction
    : public OpExprGradFunction<FusedBiasAddScaleMaskSoftmaxDropoutCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx,
                      const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(outputs.size(), 2);  // (y, softmax_y)
    CHECK_EQ_OR_RETURN(inputs.size(), 4);   // (x, bias, mask, dropout_mask)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->bias_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->scale = JUST(composed_attrs.GetAttr<float>("scale_value"));
    ctx->dropout_scale = JUST(composed_attrs.GetAttr<float>("dropout_scale_value"));
    // BUGFIX: Apply() computes dx — which is also the intermediate for the
    // bias grad — from the saved mask/dropout_mask/softmax_y. These were
    // previously saved only when x_requires_grad, so a bias-only backward
    // would index SavedTensors() with the -1 sentinel indices. Since we
    // reached this point, at least one of the two grads is required, so
    // always save them here.
    ctx->mask_index = ctx->SaveTensorForBackward(inputs.at(2));          // mask
    ctx->dropout_mask_index = ctx->SaveTensorForBackward(inputs.at(3));  // dropout_mask
    ctx->softmax_y_index = ctx->SaveTensorForBackward(outputs.at(1));    // softmax_y
    if (ctx->bias_requires_grad) {
      // NOTE(review): shape() comparison follows the original code — confirm
      // it compares shape values (not just handles) for the broadcast test.
      ctx->bias_broadcast = (inputs.at(0)->shape() != inputs.at(1)->shape());
      if (ctx->bias_broadcast) {
        ctx->bias_index = ctx->SaveTensorForBackward(inputs.at(1));  // bias
      }
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const FusedBiasAddScaleMaskSoftmaxDropoutCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {
    if (!ctx->x_requires_grad && !ctx->bias_requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // (dy, d_softmax_y)
    in_grads->resize(4);                      // (x, bias, mask, dropout_mask)
    const auto& saved_tensors = ctx->SavedTensors();
    const auto& dy = out_grads.at(0);
    CHECK_GE_OR_RETURN(saved_tensors.size(), 3);  // (mask, dropout_mask, softmax_y, [bias])
    // dx is required for the x grad and as the starting point of the bias grad.
    if (ctx->x_requires_grad || ctx->bias_requires_grad) {
      const auto& mask = saved_tensors.at(ctx->mask_index);
      const auto& dropout_mask = saved_tensors.at(ctx->dropout_mask_index);
      const auto& softmax_y = saved_tensors.at(ctx->softmax_y_index);
      in_grads->at(0) = JUST(functional::FusedScaleMaskSoftmaxDropoutGrad(
          softmax_y, dy, mask, dropout_mask, ctx->scale, ctx->dropout_scale));
    }
    if (ctx->bias_requires_grad) {
      if (ctx->bias_broadcast) {
        // Reduce dx down to bias's (broadcast) shape.
        const auto& bias = saved_tensors.at(ctx->bias_index);
        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(in_grads->at(0), bias));
      } else {
        in_grads->at(1) = in_grads->at(0);
      }
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // forward-op attrs, used as fallback in Capture
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_bias_add_scale_mask_softmax_dropout",
                               FusedBiasAddScaleMaskSoftmaxDropoutGradFunction);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Number of box-coordinate inputs of fused_get_center_dist:
// (b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2).
const int32_t INPUT_LEN = 8;

// Captured state: one requires-grad flag per input, in input order.
struct FusedCenterCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
};
// Gradient of fused_get_center_dist: delegates to the fused backward kernel
// and scatters its eight outputs into the inputs that require grads.
class FusedCenterGrad : public OpExprGradFunction<FusedCenterCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Save all eight coordinate tensors and record which of them need grads.
  Maybe<void> Capture(FusedCenterCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    for (int i = 0; i < INPUT_LEN; i++) {
      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());
      ctx->SaveTensorForBackward(inputs.at(i));
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const FusedCenterCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& rho2_diff = out_grads.at(0);
    // Saved order matches the forward input order (see Capture).
    const auto& saved = ctx->SavedTensors();
    in_grads->resize(INPUT_LEN);
    auto result = JUST(functional::FusedCenterGrad(saved.at(0), saved.at(1), saved.at(2),
                                                   saved.at(3), saved.at(4), saved.at(5),
                                                   saved.at(6), saved.at(7), rho2_diff));
    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
    for (int i = 0; i < INPUT_LEN; i++) {
      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_center_dist", FusedCenterGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Captured state for fused_fast_gelu_mul backward: a single flag set when
// either input (in, multiplier) requires a gradient.
struct FusedFastGeluMulGradCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
};
// Gradient of fused_fast_gelu_mul: the fused backward kernel returns the
// grads for both inputs (in, multiplier) at once.
class FusedFastGeluMulGrad : public OpExprGradFunction<FusedFastGeluMulGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Save (in, multiplier) only when at least one of them needs a grad.
  Maybe<void> Capture(FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // (in, multiplier)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // (out,)
    const bool any_grad = inputs.at(0)->requires_grad() || inputs.at(1)->requires_grad();
    ctx->requires_grad = any_grad;
    if (any_grad) {
      ctx->SaveTensorForBackward(inputs.at(0));  // in
      ctx->SaveTensorForBackward(inputs.at(1));  // multiplier
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const FusedFastGeluMulGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& saved = ctx->SavedTensors();
    CHECK_EQ_OR_RETURN(saved.size(), 2);
    // Fused kernel returns (in_diff, multiplier_diff).
    auto grads = JUST(functional::FusedFastGeluMulGrad(out_grads.at(0), saved.at(0), saved.at(1)));
    CHECK_EQ_OR_RETURN(grads->size(), 2);
    in_grads->resize(2);  // (in_diff, multiplier_diff)
    in_grads->at(0) = grads->at(0);
    in_grads->at(1) = grads->at(1);
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fused_fast_gelu_mul", FusedFastGeluMulGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// The op takes and produces 8 coordinate tensors: b1_{x1,x2,y1,y2} and
// b2_{x1,x2,y1,y2}.
const int32_t INPUT_LEN = 8;
// Capture state: one requires-grad flag per coordinate input, in input order.
struct FusedGetBounddingBoxesCoordGradCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
};
// Backward rule for fused_get_boundding_boxes_coord: the fused backward kernel
// maps the 8 output gradients directly to the 8 input gradients, so no forward
// tensors need to be saved.
class FusedGetBounddingBoxesCoordGrad
    : public OpExprGradFunction<FusedGetBounddingBoxesCoordGradCaptureState> {
 public:
  // No op attributes are consumed by this grad function.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Record, per input, whether its gradient is needed.
  Maybe<void> Capture(FusedGetBounddingBoxesCoordGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);
    CHECK_EQ_OR_RETURN(outputs.size(), INPUT_LEN);
    ctx->requires_grad.reserve(INPUT_LEN);
    for (int i = 0; i < INPUT_LEN; i++) {
      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());
    }
    return Maybe<void>::Ok();
  }

  // Forward the 8 output gradients through the fused backward kernel and
  // scatter the results to the inputs that require grad.
  Maybe<void> Apply(const FusedGetBounddingBoxesCoordGradCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), INPUT_LEN);
    const auto& b1_x1_diff = out_grads.at(0);
    const auto& b1_x2_diff = out_grads.at(1);
    const auto& b1_y1_diff = out_grads.at(2);
    const auto& b1_y2_diff = out_grads.at(3);
    const auto& b2_x1_diff = out_grads.at(4);
    const auto& b2_x2_diff = out_grads.at(5);
    const auto& b2_y1_diff = out_grads.at(6);
    const auto& b2_y2_diff = out_grads.at(7);
    in_grads->resize(INPUT_LEN);  // was a hard-coded 8; keep it tied to INPUT_LEN
    auto result = JUST(functional::FusedGetBounddingBoxesCoordGrad(
        b1_x1_diff, b1_x2_diff, b1_y1_diff, b1_y2_diff, b2_x1_diff, b2_x2_diff, b2_y1_diff,
        b2_y2_diff));
    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
    // Iterate by INPUT_LEN (== result->size(), checked above) to avoid the
    // signed/unsigned comparison of the original loop.
    for (int i = 0; i < INPUT_LEN; i++) {
      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_boundding_boxes_coord", FusedGetBounddingBoxesCoordGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// The op takes 4 inputs (w1, h1, w2, h2) and produces a single output v.
const int32_t INPUT_LEN = 4;
// Capture state: per-input requires-grad flags plus the "eps" attribute used
// by the backward kernel for numerical stability.
struct FusedCiouAngleCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
  float eps = 1e-8;
};
// Backward rule for fused_get_ciou_diagonal_angle (the "v" term of CIoU loss).
class FusedGetCiouDiagonalAngleGrad : public OpExprGradFunction<FusedCiouAngleCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
  // Saves all four forward inputs (w1, h1, w2, h2) and reads the "eps" attr.
  Maybe<void> Capture(FusedCiouAngleCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    for (int i = 0; i < INPUT_LEN; i++) {
      ctx->requires_grad.push_back(inputs.at(i)->requires_grad());
    }
    for (int i = 0; i < INPUT_LEN; i++) { ctx->SaveTensorForBackward(inputs.at(i)); }
    // NOTE(review): sibling grad functions in this file construct
    // ComposedAttrMap(attrs, base_attrs_) with base attrs captured in Init();
    // here the single-arg form is used and no base_attrs_ exists — confirm the
    // "eps" attr is always present on the runtime AttrMap in this path.
    ComposedAttrMap composed_attrs(attrs);
    ctx->eps = JUST(composed_attrs.GetAttr<float>("eps"));
    return Maybe<void>::Ok();
  }
  // Computes grads for (w1, h1, w2, h2) from v_diff via the fused kernel and
  // assigns only the ones whose input requires grad.
  Maybe<void> Apply(const FusedCiouAngleCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& v_diff = out_grads.at(0);
    const auto& w1 = ctx->SavedTensors().at(0);
    const auto& h1 = ctx->SavedTensors().at(1);
    const auto& w2 = ctx->SavedTensors().at(2);
    const auto& h2 = ctx->SavedTensors().at(3);
    auto result = JUST(functional::FusedGetCiouDiagonalAngleGrad(w1, h1, w2, h2, v_diff, ctx->eps));
    CHECK_EQ_OR_RETURN(result->size(), INPUT_LEN);
    in_grads->resize(INPUT_LEN);
    for (int i = 0; i < INPUT_LEN; i++) {
      if (ctx->requires_grad[i]) { in_grads->at(i) = result->at(i); }
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_diagonal_angle", FusedGetCiouDiagonalAngleGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Capture state for fused_get_ciou_result: one flag per forward input
// (v, iou, rho2, c2), in input order.
struct FusedGetCiouResultGradCaptureState : public AutoGradCaptureState {
  bool v_requires_grad = false;
  bool iou_requires_grad = false;
  bool rho2_requires_grad = false;
  bool c2_requires_grad = false;
};
// Backward rule for fused_get_ciou_result. Inputs: (v, iou, rho2, c2);
// outputs: (y, alpha). Backward needs alpha plus rho2/c2 from the forward.
class FusedGetCiouResultGrad : public OpExprGradFunction<FusedGetCiouResultGradCaptureState> {
 public:
  // No op attributes are consumed by this grad function.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 4);
    CHECK_EQ_OR_RETURN(outputs.size(), 2);
    ctx->v_requires_grad = inputs.at(0)->requires_grad();
    ctx->iou_requires_grad = inputs.at(1)->requires_grad();
    ctx->rho2_requires_grad = inputs.at(2)->requires_grad();
    ctx->c2_requires_grad = inputs.at(3)->requires_grad();
    // BUGFIX: save the backward tensors when ANY input needs a gradient.
    // Previously they were saved only when ALL four inputs required grad, so
    // Apply() tripped its saved-tensor size CHECK whenever only a subset of
    // the inputs required grad.
    if (ctx->v_requires_grad || ctx->iou_requires_grad || ctx->rho2_requires_grad
        || ctx->c2_requires_grad) {
      ctx->SaveTensorForBackward(outputs.at(1));  // alpha
      ctx->SaveTensorForBackward(inputs.at(2));   // rho2
      ctx->SaveTensorForBackward(inputs.at(3));   // c2
    }
    return Maybe<void>::Ok();
  }

  // Computes all four input grads in one fused call and assigns each one only
  // if its input requires grad.
  Maybe<void> Apply(const FusedGetCiouResultGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 2);
    in_grads->resize(4);
    const bool any_requires_grad = ctx->v_requires_grad || ctx->iou_requires_grad
                                   || ctx->rho2_requires_grad || ctx->c2_requires_grad;
    if (!any_requires_grad) { return Maybe<void>::Ok(); }
    const auto& dy = out_grads.at(0);
    const auto& saved_tensors = ctx->SavedTensors();
    CHECK_EQ_OR_RETURN(saved_tensors.size(), 3);
    const auto& alpha = saved_tensors.at(0);
    const auto& rho2 = saved_tensors.at(1);
    const auto& c2 = saved_tensors.at(2);
    auto result = JUST(functional::FusedGetCiouResultGrad(dy, alpha, rho2, c2));
    CHECK_EQ_OR_RETURN(result->size(), 4);
    if (ctx->v_requires_grad) { in_grads->at(0) = result->at(0); }
    if (ctx->iou_requires_grad) { in_grads->at(1) = result->at(1); }
    if (ctx->rho2_requires_grad) { in_grads->at(2) = result->at(2); }
    if (ctx->c2_requires_grad) { in_grads->at(3) = result->at(3); }
    return Maybe<void>::Ok();
  }
  // (removed an unused private `AttrMap base_attrs_` member — never set or read)
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_ciou_result", FusedGetCiouResultGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// The op takes 8 box-coordinate inputs and produces a single output c2.
const int32_t INPUT_LEN = 8;
// Capture state: per-input requires-grad flags plus the "eps" attribute used
// by the backward kernel for numerical stability.
struct FusedGetConvexDiagonalSquaredCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
  float eps = 1e-8;
};
// Backward rule for fused_get_convex_diagonal_squared (the c^2 term of CIoU).
class FusedGetConvexDiagonalSquaredGrad
    : public OpExprGradFunction<FusedGetConvexDiagonalSquaredCaptureState> {
 public:
  // Snapshot the op's static attrs so Capture() can resolve "eps".
  Maybe<void> Init(const OpExpr& op) override {
    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  // Save all 8 coordinate inputs and record which of them require grad.
  Maybe<void> Capture(FusedGetConvexDiagonalSquaredCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->eps = JUST(composed_attrs.GetAttr<float>("eps"));
    for (int idx = 0; idx < INPUT_LEN; idx++) {
      ctx->requires_grad.push_back(inputs.at(idx)->requires_grad());
      ctx->SaveTensorForBackward(inputs.at(idx));
    }
    return Maybe<void>::Ok();
  }

  // Run the fused backward kernel on c2_diff and scatter the 8 results to the
  // inputs that require grad. Saved order: b1_x1, b1_x2, b2_x1, b2_x2,
  // b1_y1, b1_y2, b2_y1, b2_y2.
  Maybe<void> Apply(const FusedGetConvexDiagonalSquaredCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& c2_diff = out_grads.at(0);
    const auto& saved = ctx->SavedTensors();
    in_grads->resize(INPUT_LEN);
    const auto& grads = JUST(functional::FusedGetConvexDiagonalSquaredGrad(
        c2_diff, saved.at(0), saved.at(1), saved.at(2), saved.at(3), saved.at(4), saved.at(5),
        saved.at(6), saved.at(7), ctx->eps));
    CHECK_EQ_OR_RETURN(grads->size(), INPUT_LEN);
    for (int idx = 0; idx < INPUT_LEN; idx++) {
      if (ctx->requires_grad[idx]) { in_grads->at(idx) = grads->at(idx); }
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_convex_diagonal_squared",
                               FusedGetConvexDiagonalSquaredGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// The op takes 8 box-coordinate inputs and produces a single output.
const int32_t INPUT_LEN = 8;
// Capture state: one requires-grad flag per coordinate input, in input order.
struct FusedGetIntersectionAreaCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
};
// Backward rule for fused_get_intersection_area.
class FusedGetIntersectionAreaGrad
    : public OpExprGradFunction<FusedGetIntersectionAreaCaptureState> {
 public:
  // No op attributes are consumed by this grad function.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Save all 8 coordinate inputs and record which of them require grad.
  Maybe<void> Capture(FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), INPUT_LEN);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    for (int idx = 0; idx < INPUT_LEN; idx++) {
      ctx->requires_grad.push_back(inputs.at(idx)->requires_grad());
    }
    for (int idx = 0; idx < INPUT_LEN; idx++) { ctx->SaveTensorForBackward(inputs.at(idx)); }
    return Maybe<void>::Ok();
  }

  // Run the fused backward kernel on the single output grad and scatter the 8
  // results to the inputs that require grad. Saved order: b1_x1, b1_x2,
  // b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2.
  Maybe<void> Apply(const FusedGetIntersectionAreaCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& rho2_diff = out_grads.at(0);
    const auto& saved = ctx->SavedTensors();
    in_grads->resize(INPUT_LEN);
    const auto& grads = JUST(functional::FusedGetIntersectionAreaGrad(
        saved.at(0), saved.at(1), saved.at(2), saved.at(3), saved.at(4), saved.at(5), saved.at(6),
        saved.at(7), rho2_diff));
    CHECK_EQ_OR_RETURN(grads->size(), INPUT_LEN);
    for (int idx = 0; idx < INPUT_LEN; idx++) {
      if (ctx->requires_grad[idx]) { in_grads->at(idx) = grads->at(idx); }
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_intersection_area", FusedGetIntersectionAreaGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Capture state for fused_get_iou: a single combined requires-grad flag and
// the "eps" attribute forwarded to the backward kernel.
struct FusedGetIouGradCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  float eps = 1e-8;
};
// Backward rule for fused_get_iou. Inputs: (w1, h1, w2, h2, inter); the fused
// backward kernel produces grads only for w1, h1 and inter (indices 0, 1, 4).
class FusedGetIouGrad : public OpExprGradFunction<FusedGetIouGradCaptureState> {
 public:
  // Snapshot the op's static attrs so Capture() can resolve "eps".
  Maybe<void> Init(const OpExpr& op) override {
    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }
  Maybe<void> Capture(FusedGetIouGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 5);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    // NOTE(review): the AND here means no gradient flows unless w1, h1 AND
    // inter all require grad; an any-of (||) test may have been intended, as
    // in sibling grad functions — confirm against the op's intended usage.
    ctx->requires_grad = inputs.at(0)->requires_grad() && inputs.at(1)->requires_grad()
                         && inputs.at(4)->requires_grad();
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->eps = JUST(composed_attrs.GetAttr<float>("eps"));
    if (ctx->requires_grad) {
      ctx->SaveTensorForBackward(inputs.at(0));  // w1
      ctx->SaveTensorForBackward(inputs.at(1));  // h1
      ctx->SaveTensorForBackward(inputs.at(2));  // w2
      ctx->SaveTensorForBackward(inputs.at(3));  // h2
      ctx->SaveTensorForBackward(inputs.at(4));  // inter
    }
    return Maybe<void>::Ok();
  }
  // Computes grads for (w1, h1, inter); w2 and h2 receive no gradient.
  Maybe<void> Apply(const FusedGetIouGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    const auto& diou = out_grads.at(0);
    const auto& saved_tensors = ctx->SavedTensors();
    CHECK_EQ_OR_RETURN(saved_tensors.size(), 5);
    const auto& w1 = saved_tensors.at(0);
    const auto& h1 = saved_tensors.at(1);
    const auto& w2 = saved_tensors.at(2);
    const auto& h2 = saved_tensors.at(3);
    const auto& inter = saved_tensors.at(4);
    in_grads->resize(5);
    // Kernel returns exactly 3 grads: (dw1, dh1, dinter).
    auto result = JUST(functional::FusedGetIouGrad(diou, w1, h1, w2, h2, inter, ctx->eps));
    CHECK_EQ_OR_RETURN(result->size(), 3);
    if (ctx->requires_grad) {
      in_grads->at(0) = result->at(0);
      in_grads->at(1) = result->at(1);
      in_grads->at(4) = result->at(2);
    }
    return Maybe<void>::Ok();
  }
 private:
  AttrMap base_attrs_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_get_iou", FusedGetIouGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
// Capture state for fused_matmul_bias: one flag per gradient-producing input
// (x, weight, bias), in input order.
struct FusedMatmulBiasCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool weight_requires_grad = false;
  bool bias_requires_grad = false;
};
// Backward rule for fused_matmul_bias (y = x * W^T + bias [+ add_to_output]).
// Member definitions follow out-of-line below.
class FusedMatmulBias : public OpExprGradFunction<FusedMatmulBiasCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const FusedMatmulBiasCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 protected:
  AttrMap base_attrs_;  // static attrs from the forward op conf
};
// Snapshot the forward op's static attrs into base_attrs_.
Maybe<void> FusedMatmulBias::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}
// Record per-input requires-grad flags and save x and weight for backward.
// The bias grad is a pure reduction of dy, so no tensor is saved for it.
Maybe<void> FusedMatmulBias::Capture(FusedMatmulBiasCaptureState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_GE_OR_RETURN(inputs.size(), 3)
      << "x, weight, and bias, [add_to_output] should all be included";
  const auto& x = JUST(VectorAt(inputs, 0));
  const auto& weight = JUST(VectorAt(inputs, 1));
  ctx->x_requires_grad = x->requires_grad();
  ctx->weight_requires_grad = weight->requires_grad();
  ctx->bias_requires_grad = JUST(VectorAt(inputs, 2))->requires_grad();
  ctx->SaveTensorForBackward(x);
  ctx->SaveTensorForBackward(weight);
  return Maybe<void>::Ok();
}
// Backward: dx = dy * W, dW = dy^T-style broadcast matmul with x, and
// dbias = reduce_sum(dy) over all axes except the last.
Maybe<void> FusedMatmulBias::Apply(const FusedMatmulBiasCaptureState* ctx,
                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "FusedMatmulBias more than one output";
  const auto& x = ctx->SavedTensors().at(0);
  const auto& weight = ctx->SavedTensors().at(1);
  // NOTE(review): unlike sibling Apply() implementations in this file, this
  // one never calls in_grads->resize(); it relies on the caller having sized
  // in_grads to the input count — confirm that contract holds here.
  if (ctx->x_requires_grad) {
    // dx = dy matmul W (no transpose): (.., n) x (n, k) -> (.., k)
    in_grads->at(0) =
        JUST(functional::MatMul(JUST(VectorAt(out_grads, 0)), weight, false, false, 1.0));
  }
  if (ctx->weight_requires_grad) {
    // dW via broadcast matmul grad helper: accumulates over batch dims.
    in_grads->at(1) = JUST(functional::BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), x, 1.0));
  }
  if (ctx->bias_requires_grad) {
    // dbias: sum dy over every axis except the trailing feature axis.
    const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
    std::vector<int32_t> reduce_axes_vec;
    reduce_axes_vec.reserve(num_axes - 1);
    for (int i = 0; i < num_axes - 1; i++) { reduce_axes_vec.push_back(i); }
    in_grads->at(2) =
        JUST(functional::ReduceSum(JUST(VectorAt(out_grads, 0)), reduce_axes_vec, false));
  }
  return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_matmul_bias", FusedMatmulBias);
} // namespace one
} // namespace oneflow
......@@ -84,7 +84,7 @@ Maybe<void> FusedMatmulBiasAddReluDropout::Capture(FusedMatmulBiasAddReluDropout
ctx->SaveTensorForBackward(
JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w
}
for (int32_t i = 0; i < weight_num - 1; i++) {
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden.
}
......@@ -101,7 +101,7 @@ Maybe<void> FusedMatmulBiasAddReluDropout::Apply(
int32_t weight_num = ctx->weight_num;
in_grads->resize(1 + 2 * weight_num);
TensorTuple hiddens(weight_num - 1);
TensorTuple hiddens(weight_num);
TensorTuple weights(weight_num);
TensorTuple cublas_auxs(weight_num);
TensorTuple dgrad(weight_num);
......@@ -117,9 +117,10 @@ Maybe<void> FusedMatmulBiasAddReluDropout::Apply(
cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));
}
for (int32_t i = 0; i < weight_num - 1; ++i) {
for (int32_t i = 0; i < weight_num; ++i) {
hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));
}
float rate = ctx->dropout_rate_list.at(weight_num - 1);
float scale = 0.0f;
if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }
......@@ -136,6 +137,36 @@ Maybe<void> FusedMatmulBiasAddReluDropout::Apply(
cublas_auxs[weight_num - 1], scale));
}
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) {
std::vector<float> alpha_list(weight_num - 1, 1.0);
for (int i = 0; i < weight_num - 1; i++) {
rate = ctx->dropout_rate_list.at(i);
scale = 1.0;
if (rate < 1.0f) { scale = 1.0f / (1.0f - rate); }
alpha_list.at(i) = scale;
}
const auto& fused_mlp_grad =
JUST(functional::FusedMLPGrad(last_bias_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights,
cublas_auxs, hiddens, alpha_list));
if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0);
}
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) {
if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) =
fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) =
fused_mlp_grad->at(1 + weight_num + hidden_layer_idx);
}
}
} else {
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
......@@ -190,8 +221,9 @@ Maybe<void> FusedMatmulBiasAddReluDropout::Apply(
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
JUST(VectorAt(*in_grads, 1)) = JUST(
functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
}
}
return Maybe<void>::Ok();
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
// Capture state for fused_weighted_sum: per-input requires-grad flags plus
// the per-input weights and the global alpha scaling attribute.
struct FusedWeightedSumCaptureState : public AutoGradCaptureState {
  std::vector<bool> requires_grad;
  std::vector<float> weights;
  float alpha{};
};
// Backward rule for fused_weighted_sum: y = alpha * sum_i(weights[i] * x_i),
// so each input grad is simply dy scaled by weights[i] * alpha.
class FusedWeightedSum : public OpExprGradFunction<FusedWeightedSumCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Record the weights/alpha attrs and which inputs require grad.
  Maybe<void> Capture(FusedWeightedSumCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad.resize(inputs.size());
    ctx->weights = JUST(attrs.GetAttr<std::vector<float>>("weights"));
    ctx->alpha = JUST(attrs.GetAttr<float>("alpha"));
    CHECK_EQ_OR_RETURN(ctx->weights.size(), inputs.size());
    for (int idx = 0; idx < inputs.size(); ++idx) {
      ctx->requires_grad[idx] = inputs[idx]->requires_grad();
    }
    return Maybe<void>::Ok();
  }

  // Each requested input grad is dy scaled by its weight times alpha.
  Maybe<void> Apply(const FusedWeightedSumCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& dy = out_grads[0];
    in_grads->resize(ctx->requires_grad.size());
    for (int idx = 0; idx < ctx->requires_grad.size(); ++idx) {
      if (!ctx->requires_grad[idx]) { continue; }
      (*in_grads)[idx] = JUST(functional::ScalarMul(dy, ctx->weights[idx] * ctx->alpha, false));
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("fused_weighted_sum", FusedWeightedSum);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Shared capture state for the local<->global cast grad functions. Not all
// fields are used by both directions: LocalToGlobal captures only
// parallel_desc and nd_sbp; GlobalToLocal additionally captures shape/dtype
// to rebuild the global tensor in backward.
struct CastGlobalCaptureState : public AutoGradCaptureState {
  Symbol<ParallelDesc> parallel_desc;
  Symbol<NdSbp> nd_sbp;
  std::shared_ptr<const Shape> shape;
  Symbol<DType> dtype;
};
// Backward of local_to_global is global_to_local: the incoming global grad is
// first re-boxed to the dual SBP on the captured placement, then converted
// back to a local tensor.
class LocalToGlobal : public OpExprGradFunction<CastGlobalCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const LocalToGlobalOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    const std::string& op_name = fw_op_expr->op_name();
    // Backward op is the inverse cast, named after the forward op.
    grad_op_ = JUST(one::GlobalToLocalOpExpr::New(GradientOpName(op_name)));
    return Maybe<void>::Ok();
  }
  // Note: this overload takes the OpExprInterpContext (not AttrMap) because
  // placement/SBP live on the interpreter context for cast ops.
  Maybe<void> Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs,
                      const OpExprInterpContext& interp_ctx) const override {
    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
    // The dual SBP is what the gradient must be boxed to in backward.
    ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));
    return Maybe<void>::Ok();
  }
  Maybe<void> Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    std::shared_ptr<Tensor> out_grad = out_grads.at(0);
    CHECK_OR_RETURN(out_grad->is_global())
        << Error::RuntimeError()
        << "Expected global tensor for local_to_global but got local tensor";
    {
      // Re-box the grad to the captured dual SBP/placement before casting
      // back to local; check_meta is skipped since the meta came from capture.
      Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
      Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
      out_grad = JUST(functional::ToGlobal(out_grad, parallel_desc_constraint,
                                           *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList(),
                                           /* check_meta */ false, /*copy=*/false));
    }
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;  // the global_to_local backward op expr
};
REGISTER_OP_EXPR_GRAD_FUNCTION("local_to_global", LocalToGlobal);
// Backward of global_to_local is local_to_global: the local grad is lifted
// back to a global tensor with the dual SBP of the forward input.
class GlobalToLocal : public OpExprGradFunction<CastGlobalCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const GlobalToLocalOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    const std::string& op_name = fw_op_expr->op_name();
    // Backward op is the inverse cast, named after the forward op.
    grad_op_ = JUST(one::LocalToGlobalOpExpr::New(GradientOpName(op_name)));
    return Maybe<void>::Ok();
  }
  // Captures everything needed to reconstruct the global tensor in backward:
  // placement, SBP, logical shape and dtype of the forward input.
  Maybe<void> Capture(CastGlobalCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    const auto& input = inputs.at(0);
    CHECK_OR_RETURN(input->is_global())
        << Error::RuntimeError()
        << "Expected global tensor for global_to_local but got local tensor";
    ctx->parallel_desc = JUST(input->parallel_desc());
    ctx->nd_sbp = JUST(input->nd_sbp());
    ctx->shape = input->shape();
    ctx->dtype = input->dtype();
    return Maybe<void>::Ok();
  }
  Maybe<void> Apply(const CastGlobalCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    // Gradients flow with the dual SBP of the forward input.
    const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape", "dtype", "sync_data");
    attrs.SetAllAttrs(*ctx->shape, ctx->dtype->data_type(), true);
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(
        *grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;  // the local_to_global backward op expr
};
REGISTER_OP_EXPR_GRAD_FUNCTION("global_to_local", GlobalToLocal);
} // namespace one
} // namespace oneflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment