"profiler/vscode:/vscode.git/clone" did not exist on "31ea132aa21ec37fe11735b2e0e71041b95f1911"
Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <functional>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of broadcast_div_grad.
// Records which of (y, z, grad) need gradients and where each saved tensor
// lives inside SavedTensors(); the *_index defaults match the save order
// used in DivGradGrad::Capture when everything requires grad.
struct DivGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;     // divisor y of the forward z = x / y
bool z_requires_grad = false;     // forward output z
bool grad_requires_grad = false;  // incoming first-order grad dz
size_t y_index = 0;
size_t z_index = 1;
size_t grad_index = 2;
};
// Double-backward of broadcast_div_grad (inputs: dz, z, y; one output).
class DivGradGrad : public OpExprGradFunction<DivGradGradCaptureState> {
// div_grad = -x/(y*y)*dz = -z/y*dz
// div_grad_y = out_grad * z*dz/(y*y)
// div_grad_z = out_grad * -dz/y
// div_grad_dz = out_grad * -z/y
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(DivGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// input order: dz, z, y
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs.at(0)->requires_grad();
ctx->z_requires_grad = inputs.at(1)->requires_grad();
ctx->y_requires_grad = inputs.at(2)->requires_grad();
// y is read by every branch of Apply, so it is saved unconditionally.
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(2));
// z is only read when computing grads for dz or y.
if (ctx->y_requires_grad || ctx->grad_requires_grad) {
ctx->z_index = ctx->SaveTensorForBackward(inputs.at(1));
}
// dz is only read when computing grads for z or y.
if (ctx->y_requires_grad || ctx->z_requires_grad) {
ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const DivGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(3);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
// grad for dz: -(out_grad * z) / y
if (ctx->grad_requires_grad) {
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
.then(functional::Negative)
.then(std::bind(functional::Div, std::placeholders::_1, y))
.call(out_grads.at(0), z));
}
// grad for z: -(out_grad * dz) / y
if (ctx->z_requires_grad) {
const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
in_grads->at(1) = JUST(functional::sequence_function(functional::Mul)
.then(functional::Negative)
.then(std::bind(functional::Div, std::placeholders::_1, y))
.call(out_grads.at(0), grad));
}
// grad for y: reduce (z * dz) back to y's shape (forward op broadcasts),
// then * out_grad / y^2, matching div_grad_y above.
if (ctx->y_requires_grad) {
const auto& z = ctx->SavedTensors().at(ctx->z_index);
const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
in_grads->at(2) = JUST(
functional::sequence_function(functional::Mul)
.then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
.then(std::bind(functional::Div, std::placeholders::_1, JUST(functional::Square(y))))
.call(z, grad));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div_grad", DivGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of kl_div_loss_grad.
struct KLDivLossGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;    // incoming first-order grad
bool input_requires_grad = false;
bool target_requires_grad = false;
bool log_target = false;  // value of the op's "log_target" attribute
size_t input_index = 0;   // slot of saved `input` in SavedTensors()
size_t target_index = 0;  // slot of saved `target` in SavedTensors()
};
// Double-backward of kl_div_loss_grad (inputs: grad, input, target).
class KLDivLossGradGrad : public OpExprGradFunction<KLDivLossGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(KLDivLossGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const KLDivLossGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
// Attributes of the forward op; overlaid with runtime attrs in Capture.
AttrMap base_attrs_;
};
// Cache the forward op's attribute map so Capture can resolve "log_target".
Maybe<void> KLDivLossGradGrad::Init(const OpExpr& op) {
  const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
  return Maybe<void>::Ok();
}
Maybe<void> KLDivLossGradGrad::Capture(KLDivLossGradGradCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
// input order: grad, input, target
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ctx->target_requires_grad = inputs[2]->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->log_target = JUST(composed_attrs.GetAttr<bool>("log_target"));
// input/target are only needed to recompute the first-order grad formula,
// i.e. only when `grad` itself requires grad (see Apply).
if (ctx->grad_requires_grad) {
ctx->input_index = ctx->SaveTensorForBackward(inputs[1]); // input
ctx->target_index = ctx->SaveTensorForBackward(inputs[2]); // target
}
return Maybe<void>::Ok();
}
Maybe<void> KLDivLossGradGrad::Apply(const KLDivLossGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3);
// kl_div_loss_grad is linear in `grad`, so its grad w.r.t. `grad` is the
// same op applied to out_grads[0].
if (ctx->grad_requires_grad) {
const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
(*in_grads)[0] = JUST(functional::KLDivLossGrad(out_grads[0], input, target, ctx->log_target));
}
// The first-order grad does not depend on input/target beyond sign masks,
// so their second-order grads are zero.
if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::ZerosLike(out_grads[0])); }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("kl_div_loss_grad", KLDivLossGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of log_softmax_grad (inputs: y, dy).
struct LogSoftmaxGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;
bool dy_requires_grad = false;
};
// Double-backward of log_softmax_grad; y is the forward log_softmax output.
class LogSoftmaxGradGrad : public OpExprGradFunction<LogSoftmaxGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
// No attributes need caching for log_softmax's double-backward.
Maybe<void> LogSoftmaxGradGrad::Init(const OpExpr& op) {
  return Maybe<void>::Ok();
}
Maybe<void> LogSoftmaxGradGrad::Capture(LogSoftmaxGradGradCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
// input order: y, dy
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->y_requires_grad = inputs[0]->requires_grad();
ctx->dy_requires_grad = inputs[1]->requires_grad();
// y (slot 0) is read by both branches of Apply; dy (slot 1) only when
// computing the grad w.r.t. y.
ctx->SaveTensorForBackward(inputs[0]);
if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]);
return Maybe<void>::Ok();
}
Maybe<void> LogSoftmaxGradGrad::Apply(const LogSoftmaxGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
const auto& y = ctx->SavedTensors()[0];
// Softmax reduction is taken over the last axis of y.
const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};
// grad for y: -exp(y) * out_grad * sum(dy, last_dim, keepdim)
if (ctx->y_requires_grad) {
const auto& dy = ctx->SavedTensors()[1];
in_grads->at(0) =
JUST(functional::sequence_function(functional::ReduceSum)
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::Mul, std::placeholders::_1, JUST(functional::Exp(y))))
.then(functional::Negative)
.call(dy, reduce_axis, true));
}
// grad for dy: out_grad - sum(exp(y) * out_grad, last_dim, keepdim)
if (ctx->dy_requires_grad) {
in_grads->at(1) =
JUST(functional::sequence_function(functional::Exp)
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis,
/*keepdim=*/true))
.then(std::bind(functional::Sub, out_grads[0], std::placeholders::_1, /*alpha=*/1,
/*inplace=*/false))
.call(y));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("log_softmax_grad", LogSoftmaxGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Shared captured state for the double-backward of unary elementwise math ops
// whose first-order backward takes (x or y, dy).
struct UnaryMathGradGradState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool grad_requires_grad = false;
};
// Signature shared by the functional first-order backward (BwFunc) and the
// second-order backward (BwBwFunc) of a unary elementwise op.
typedef Maybe<one::Tensor> (*UnaryBwFunc)(const std::shared_ptr<one::Tensor>&,
const std::shared_ptr<one::Tensor>&);
// Generic double-backward for unary ops: the first-order backward is
// dx = BwFunc-derivative(input) * grad, which is linear in `grad`, so
// - grad w.r.t. input uses BwBwFunc (the second derivative term),
// - grad w.r.t. grad reuses BwFunc itself.
template<UnaryBwFunc BwFunc, UnaryBwFunc BwBwFunc>
class UnaryMathGradGrad : public OpExprGradFunction<UnaryMathGradGradState> {
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// input order: input (x or y), grad (dy)
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->grad_requires_grad = inputs[1]->requires_grad();
// input at slot 0 is always needed; grad at slot 1 only for the
// BwBwFunc branch.
ctx->SaveTensorForBackward(inputs[0]);
if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs[1]); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& input = ctx->SavedTensors()[0];
if (ctx->input_requires_grad) {
const auto& grad = ctx->SavedTensors()[1];
(*in_grads)[0] = JUST(functional::Mul(out_grads[0], JUST(BwBwFunc(input, grad))));
}
if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }
return Maybe<void>::Ok();
}
};
// Variant for ops whose second derivative is zero almost everywhere (e.g.
// abs): the grad w.r.t. input is identically zero, so only BwFunc is needed.
template<UnaryBwFunc BwFunc>
class UnaryMathGradGradWithZeroDDX : public OpExprGradFunction<UnaryMathGradGradState> {
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->grad_requires_grad = inputs[1]->requires_grad();
ctx->SaveTensorForBackward(inputs[0]);
return Maybe<void>::Ok();
}
Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);
const auto& input = ctx->SavedTensors()[0];
// d(dx)/d(input) == 0 for these ops.
if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::ZerosLike(input)); }
if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }
return Maybe<void>::Ok();
}
};
// TODO: Lgamma, first order backward unimplemented
// Ops whose first-order backward is a function of the ORIGINAL INPUT x:
// each entry instantiates UnaryMathGradGrad<OpGrad, OpGradGrad>.
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ \
OF_PP_MAKE_TUPLE_SEQ("sin_grad", Sin) \
OF_PP_MAKE_TUPLE_SEQ("cos_grad", Cos) \
OF_PP_MAKE_TUPLE_SEQ("tan_grad", Tan) \
OF_PP_MAKE_TUPLE_SEQ("sinh_grad", Sinh) \
OF_PP_MAKE_TUPLE_SEQ("cosh_grad", Cosh) \
OF_PP_MAKE_TUPLE_SEQ("tanh_grad", Tanh) \
OF_PP_MAKE_TUPLE_SEQ("asin_grad", Asin) \
OF_PP_MAKE_TUPLE_SEQ("acos_grad", Acos) \
OF_PP_MAKE_TUPLE_SEQ("atan_grad", Atan) \
OF_PP_MAKE_TUPLE_SEQ("asinh_grad", Asinh) \
OF_PP_MAKE_TUPLE_SEQ("acosh_grad", Acosh) \
OF_PP_MAKE_TUPLE_SEQ("atanh_grad", Atanh) \
OF_PP_MAKE_TUPLE_SEQ("erf_grad", Erf) \
OF_PP_MAKE_TUPLE_SEQ("erfc_grad", Erfc) \
OF_PP_MAKE_TUPLE_SEQ("exp_grad", Exp) \
OF_PP_MAKE_TUPLE_SEQ("expm1_grad", Expm1) \
OF_PP_MAKE_TUPLE_SEQ("log_grad", Log) \
OF_PP_MAKE_TUPLE_SEQ("log_sigmoid_grad", LogSigmoid) \
OF_PP_MAKE_TUPLE_SEQ("log2_grad", Log2) \
OF_PP_MAKE_TUPLE_SEQ("log1p_grad", Log1p) \
OF_PP_MAKE_TUPLE_SEQ("reciprocal_grad", Reciprocal) \
OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan_grad", ReciprocalNoNan) \
OF_PP_MAKE_TUPLE_SEQ("rsqrt_grad", Rsqrt) \
OF_PP_MAKE_TUPLE_SEQ("sqrt_grad", Sqrt) \
OF_PP_MAKE_TUPLE_SEQ("square_grad", Square)
// Ops whose first-order backward is a function of the forward OUTPUT y.
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("sigmoid_grad", Sigmoid)
// Ops with an (almost-everywhere) zero second derivative.
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("abs_grad", Abs)
// Instantiate a concrete grad-grad class per op and register it under the
// op's first-order-backward op_type_name.
#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS(op_type_name, op_cls) \
class op_cls##GradGradCls final \
: public UnaryMathGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> {}; \
REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);
OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,
MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ);
OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,
MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ);
// Same, but for the zero-second-derivative variant.
#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS(op_type_name, op_cls) \
class op_cls##GradGradCls final \
: public UnaryMathGradGradWithZeroDDX<functional::op_cls##Grad> {}; \
REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);
OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS,
MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of broadcast_matmul_grad_b.
struct BroadcastMatmulGradBGradCaptureState : public AutoGradCaptureState {
bool a_requires_grad = false;
bool b_requires_grad = false;
size_t a_index = 0;   // slot of saved input a in SavedTensors()
size_t b_index = 1;   // slot of saved input b in SavedTensors()
double alpha = 1.0;   // scaling attr forwarded from the first-order op
};
// Double-backward of broadcast_matmul_grad_b (inputs: a, b; one output).
class BroadcastMatmulGradBGrad : public OpExprGradFunction<BroadcastMatmulGradBGradCaptureState> {
 public:
  // Caches the forward op's attributes so Capture can read "alpha".
  Maybe<void> Init(const OpExpr& op) override {
    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }  // NOTE: dropped the stray ';' that followed this member definition.
  Maybe<void> Capture(BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->a_requires_grad = inputs.at(0)->requires_grad();
    ctx->b_requires_grad = inputs.at(1)->requires_grad();
    // Each input's grad only reads the OTHER input, so save only what Apply
    // will actually use.
    if (ctx->a_requires_grad) { ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1)); }
    if (ctx->b_requires_grad) { ctx->a_index = ctx->SaveTensorForBackward(inputs.at(0)); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    return Maybe<void>::Ok();
  }
  Maybe<void> Apply(const BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    // for matmul: input_a[dims..., m, k] * input_b[k, n] -> [dims..., m, n]
    // if forward: BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha))
    // then: a.shape = [dims..., m, k], b.shape = [dims..., m, n], grad.shape = [k, n]
    // if forward: BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha))
    // then: a.shape = [dims..., m, n], b.shape = [dims..., m, k], grad.shape = [n, k]
    if (ctx->a_requires_grad) {
      // grad_a = alpha * b @ out_grad^T
      const auto& b = ctx->SavedTensors()[ctx->b_index];
      in_grads->at(0) = JUST(functional::MatMul(b, out_grads.at(0), false, true, ctx->alpha));
    }
    if (ctx->b_requires_grad) {
      // grad_b = alpha * a @ out_grad
      const auto& a = ctx->SavedTensors()[ctx->a_index];
      in_grads->at(1) = JUST(functional::MatMul(a, out_grads.at(0), false, false, ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // forward op attrs cached by Init
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul_grad_b", BroadcastMatmulGradBGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of max_pool_{1,2,3}d_grad.
struct MaxPoolGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;
bool input_requires_grad = false;
};
// Double-backward of N-d max pooling's first-order backward
// (inputs: dy, x, indice; one output).
template<int ndims>
class MaxPoolNdGradGrad : public OpExprGradFunction<MaxPoolGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(MaxPoolGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// input order: dy, x, indice
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
// Only the pooling indices are needed to route gradients in Apply.
if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs[2]); }
return Maybe<void>::Ok();
}
Maybe<void> Apply(const MaxPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3);
// Max-pool backward is a selection (gather by indices), hence linear in dy.
if (ctx->grad_requires_grad) {
const auto& indices = JUST(VectorAt(ctx->SavedTensors(), 0));
(*in_grads)[0] = JUST(functional::MaxPoolNdGradGrad(out_grads[0], indices, ndims));
}
// Selection is piecewise constant in x, so its second-order grad is zero.
if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_1d_grad", MaxPoolNdGradGrad<1>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_2d_grad", MaxPoolNdGradGrad<2>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_3d_grad", MaxPoolNdGradGrad<3>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d_grad", MaxPoolNdGradGrad<1>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d_grad", MaxPoolNdGradGrad<2>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d_grad", MaxPoolNdGradGrad<3>);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of nll_grad.
struct NLLCaptureState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool grad_requires_grad = false;
bool has_weight = false;        // true when the optional weight input exists
int64_t ignore_index = -100;    // targets equal to this are ignored
};
// Double-backward of nll_grad (inputs: dy, input, target[, weight]).
class NLLLossGradGrad : public OpExprGradFunction<NLLCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
// Attributes of the forward op; overlaid with runtime attrs in Capture.
AttrMap base_attrs_;
};
// Cache the forward op's attribute map so Capture can resolve "ignore_index".
Maybe<void> NLLLossGradGrad::Init(const OpExpr& op) {
  const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
  return Maybe<void>::Ok();
}
Maybe<void> NLLLossGradGrad::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
// input order: dy, input, target[, weight] (weight is optional)
CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg)
ctx->grad_requires_grad = inputs[0]->requires_grad();
ctx->input_requires_grad = inputs[1]->requires_grad();
ctx->has_weight = inputs.size() == 4;
// target/weight/ignore_index are only consumed when dy itself requires grad.
if (ctx->grad_requires_grad) {
ctx->SaveTensorForBackward(inputs[2]);
if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[3]); } // weight
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>("ignore_index"));
}
return Maybe<void>::Ok();
}
Maybe<void> NLLLossGradGrad::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(3 + ctx->has_weight);
// nll_grad is linear in dy, so its grad w.r.t. dy is NLLLoss applied to
// out_grads with reduction "none".
if (ctx->grad_requires_grad) {
const auto& target = JUST(VectorAt(ctx->SavedTensors(), 0));
if (ctx->has_weight) {
auto weight = JUST(VectorAt(ctx->SavedTensors(), 1));
(*in_grads)[0] =
JUST(functional::NLLLoss(out_grads[0], target, weight, ctx->ignore_index, "none"));
} else {
(*in_grads)[0] =
JUST(functional::NLLLoss(out_grads[0], target, NullOpt, ctx->ignore_index, "none"));
}
}
// nll_grad does not depend on input values (only on target selection), so
// the second-order grad w.r.t. input is zero. NOTE(review): in_grads[3]
// (weight) is intentionally left unset here — verify weight never requires
// grad through this path.
if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("nll_grad", NLLLossGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <functional>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// Captured state for the double-backward of pow_x_grad (inputs: x, y, dz).
// The *_index defaults match Capture's save order when all inputs require grad.
struct PowXGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool y_requires_grad = false;
bool dz_requires_grad = false;
size_t x_index = 0;
size_t y_index = 1;
size_t dz_index = 2;
};
// Double-backward of pow_x_grad, whose forward computes dx = y * x^(y-1) * dz.
class PowXGradGrad : public OpExprGradFunction<PowXGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(PowXGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
// input order: x, y, dz
CHECK_EQ_OR_RETURN(inputs.size(), 3); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->y_requires_grad = inputs.at(1)->requires_grad();
ctx->dz_requires_grad = inputs.at(2)->requires_grad();
// x and y are read by every branch of Apply; dz only by the x/y branches.
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
if (ctx->x_requires_grad || ctx->y_requires_grad) {
ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2));
}
return Maybe<void>::Ok();
}
Maybe<void> Apply(const PowXGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(3);
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
// dx = y * x^(y-1) * dz
// grad_for_x = out_grads * dz * y * [x^(y-1)]'
// grad_for_y = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]
// grad_for_dz = out_grads * y * x^(y-1)
if (ctx->x_requires_grad || ctx->y_requires_grad) {
const auto& dz = ctx->SavedTensors().at(ctx->dz_index);
const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));
if (ctx->x_requires_grad) {
// PowXGrad(x, y-1, out_grad) computes (y-1)*x^(y-2)*out_grad = [x^(y-1)]' * out_grad.
in_grads->at(0) = JUST(functional::sequence_function(functional::PowXGrad)
.then(std::bind(functional::Mul, std::placeholders::_1, y))
.then(std::bind(functional::Mul, std::placeholders::_1, dz))
.call(x, y_sub_one, out_grads.at(0)));
}
if (ctx->y_requires_grad) {
// Builds (1 + y*ln(x)) * x^(y-1) * dz * out_grad step by step.
in_grads->at(1) =
JUST(functional::sequence_function(functional::Log)
.then(std::bind(functional::Mul, std::placeholders::_1, y))
.then([](const std::shared_ptr<Tensor>& input) {
return functional::ScalarAdd(1, input, /*alpha=*/1);
})
.then(std::bind(functional::Mul, std::placeholders::_1,
JUST(functional::Pow(x, y_sub_one))))
.then(std::bind(functional::Mul, std::placeholders::_1, dz))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
.call(x));
}
}
if (ctx->dz_requires_grad) {
// Linear in dz: reuse the first-order formula on out_grads.
in_grads->at(2) = JUST(functional::PowXGrad(x, y, out_grads.at(0)));
}
return Maybe<void>::Ok();
}
};
// Captured state for the double-backward of pow_y_grad (inputs: x, y, dz;
// output dy is also saved when y requires grad).
struct PowYGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;
bool y_requires_grad = false;
bool dz_requires_grad = false;
size_t x_index = 0;
size_t y_index = 1;
size_t dz_index = 2;
size_t dy_index = 3;
};
// Double-backward of pow_y_grad, whose forward computes
// dy = x^y * ln(x) * dz = z * ln(x) * dz.
class PowYGradGrad : public OpExprGradFunction<PowYGradGradCaptureState> {
 public:
  // dy = x^y*ln(x)*dz = z*ln(x)*dz
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
  Maybe<void> Capture(PowYGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // input order: x, y, dz
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->dz_requires_grad = inputs.at(2)->requires_grad();
    // x is read by every branch of Apply.
    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    // BUGFIX: the dz branch of Apply also reads y, so y must be saved when
    // dz requires grad. Previously the condition was only (x || y), and
    // SavedTensors().at(y_index) threw std::out_of_range when dz alone
    // required grad.
    if (ctx->x_requires_grad || ctx->y_requires_grad || ctx->dz_requires_grad) {
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    if (ctx->x_requires_grad) { ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2)); }
    if (ctx->y_requires_grad) { ctx->dy_index = ctx->SaveTensorForBackward(outputs.at(0)); }
    return Maybe<void>::Ok();
  }
  Maybe<void> Apply(const PowYGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    // dy = x^y * ln(x) * dz = z * ln(x) * dz
    // grad_for_x = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]
    // grad_for_y = out_grads * dy' = out_grads * dy * ln(x)
    // grad_for_dz = out_grads * x^y * ln(x)
    if (ctx->x_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      const auto& dz = ctx->SavedTensors().at(ctx->dz_index);
      const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));
      // Builds (1 + y*ln(x)) * x^(y-1) * dz * out_grad step by step.
      in_grads->at(0) =
          JUST(functional::sequence_function(functional::Log)
                   .then(std::bind(functional::Mul, std::placeholders::_1, y))
                   .then([](const std::shared_ptr<Tensor>& input) {
                     return functional::ScalarAdd(1, input, /*alpha=*/1);
                   })
                   .then(std::bind(functional::Mul, std::placeholders::_1,
                                   JUST(functional::Pow(x, y_sub_one))))
                   .then(std::bind(functional::Mul, std::placeholders::_1, dz))
                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
                   .call(x));
    }
    if (ctx->y_requires_grad) {
      // grad_for_y = ln(x) * dy * out_grad (dy is the saved forward output).
      const auto& dy = ctx->SavedTensors().at(ctx->dy_index);
      in_grads->at(1) =
          JUST(functional::sequence_function(functional::Log)
                   .then(std::bind(functional::Mul, std::placeholders::_1, dy))
                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
                   .call(x));
    }
    if (ctx->dz_requires_grad) {
      // Linear in dz: reuse the first-order formula on out_grads.
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      in_grads->at(2) = JUST(functional::PowYGrad(x, y, out_grads.at(0)));
    }
    return Maybe<void>::Ok();
  }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("pow_x_grad", PowXGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("pow_y_grad", PowYGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// State captured for the double-backward of scalar_pow_grad / scalar_reverse_pow_grad:
// which inputs need gradients, plus the scalar operand read from the forward op's attrs.
struct ScalarPowGradGradCaptureState : public AutoGradCaptureState {
bool x_requires_grad = false;     // tensor input (x) needs a gradient
bool grad_requires_grad = false;  // incoming first-order gradient (dz) needs a gradient
Scalar operand;  // the scalar exponent/base, taken from float_operand or int_operand attr
};
// Double-backward for scalar_pow_grad.
// Forward first-order grad: z = x^a, dx = a * x^(a-1) * dz.
// grad_for_x  = out_grad * a * dz * [x^(a-1)]'
// grad_for_dz = out_grad * [x^a]'
class ScalarPowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    // Validate the cast BEFORE dereferencing it: the previous code called
    // fw_op_expr->proto() first, which is undefined behavior when the
    // dynamic_cast yields nullptr.
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // inputs: (x, dz); outputs: (dx,)
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
    if (has_float_operand) {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
    } else {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
    }
    // x is needed by both output grads; dz only when x itself needs a gradient.
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);
    in_grads->resize(2);
    // z = x^a, dx = a * x^(a-1) * dz
    // grad_for_x  = out_grad * a * dz * [x^(a-1)]'
    // grad_for_dz = out_grad * [x^a]'
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      const auto operand_sub_one = ctx->operand - Scalar(1);
      // (grad * out_grad) -> ScalarPowGrad(x, ., a-1) -> scale by a
      in_grads->at(0) = JUST(
          functional::sequence_function(functional::Mul)
              .then(std::bind(functional::ScalarPowGrad, x, std::placeholders::_1, operand_sub_one))
              .then([&ctx](const std::shared_ptr<Tensor>& input) {
                return functional::ScalarMul(ctx->operand, input);
              })
              .call(grad, out_grads.at(0)));
    }
    if (ctx->grad_requires_grad) {
      in_grads->at(1) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attrs from the forward op conf, merged with runtime attrs in Capture
};
// Double-backward for scalar_reverse_pow_grad.
// Forward first-order grad: z = a^x, dx = a^x * ln(a) * dz.
// grad_for_x  = out_grad * dx * ln(a)   (dx already carries one factor of ln(a))
// grad_for_dz = out_grad * [a^x]'
class ScalarReversePowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    // Validate the cast BEFORE dereferencing it: the previous code called
    // fw_op_expr->proto() first, which is undefined behavior when the
    // dynamic_cast yields nullptr.
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // inputs: (x, dz); outputs: (dx,)
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
    if (has_float_operand) {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
    } else {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
    }
    // Save x for the dz-grad; save the forward output dx for the x-grad,
    // since grad_for_x reuses dx rather than recomputing a^x * ln(a) * dz.
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);
    in_grads->resize(2);
    // z = a^x, dx = a^x * ln(a) * dz
    // grad_for_x  = out_grad * dz * a^x * ln(a) * ln(a) = out_grad * dx * ln(a)
    // grad_for_dz = out_grad * [a^x]'
    if (ctx->x_requires_grad) {
      const auto& dx = ctx->SavedTensors().at(1);
      const auto log_operand = std::log(ctx->operand.As<double>());
      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
                                 .then([&log_operand](const std::shared_ptr<Tensor>& input) {
                                   return functional::ScalarMul(log_operand, input);
                                 })
                                 .call(dx, out_grads.at(0)));
    }
    if (ctx->grad_requires_grad) {
      in_grads->at(1) = JUST(functional::ScalarReversePowGrad(x, out_grads.at(0), ctx->operand));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attrs from the forward op conf, merged with runtime attrs in Capture
};
// Register the double-backward functions for the scalar pow gradient ops.
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow_grad", ScalarPowGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_reverse_pow_grad", ScalarReversePowGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// State captured for the double-backward of slice_grad: the slice window
// attributes of the forward op, needed to re-slice the incoming gradient.
struct SliceGradGradCaptureState : public AutoGradCaptureState {
std::vector<int64_t> start;  // per-dimension slice start indices
std::vector<int64_t> stop;   // per-dimension slice stop indices
std::vector<int64_t> step;   // per-dimension slice strides
};
// Double-backward for slice_grad: slicing the incoming gradient with the same
// start/stop/step window recovers the gradient w.r.t. the forward op's input.
class SliceGradGrad : public OpExprGradFunction<SliceGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SliceGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    // Merge runtime attrs over the forward op conf and stash the slice window.
    ComposedAttrMap merged(attrs, base_attrs_);
    ctx->start = JUST(merged.GetAttr<std::vector<int64_t>>("start"));
    ctx->stop = JUST(merged.GetAttr<std::vector<int64_t>>("stop"));
    ctx->step = JUST(merged.GetAttr<std::vector<int64_t>>("step"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SliceGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    // Re-apply the captured slice window to the incoming gradient.
    (*in_grads)[0] = JUST(functional::Slice(out_grads.at(0), ctx->start, ctx->stop, ctx->step,
                                            /*enable_view_slice=*/false));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attrs from the forward op conf
};
// Register the double-backward function for slice_grad.
REGISTER_OP_EXPR_GRAD_FUNCTION("slice_grad", SliceGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// State captured for the double-backward of smooth_l1_loss_grad.
struct SmoothL1LossGradGradCaptureState : public AutoGradCaptureState {
bool grad_requires_grad = false;    // incoming first-order gradient needs a gradient
bool input_requires_grad = false;   // prediction tensor needs a gradient
bool target_requires_grad = false;  // target tensor needs a gradient
size_t grad_index = 0;    // saved-tensor index of grad (only set when input/target need grads)
size_t input_index = 0;   // saved-tensor index of input
size_t target_index = 0;  // saved-tensor index of target
float beta = 0.0;         // smooth-l1 quadratic/linear transition point, from the op attr
};
// Double-backward for smooth_l1_loss_grad.
class SmoothL1LossGradGrad : public OpExprGradFunction<SmoothL1LossGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* user_op = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(user_op);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(user_op->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // inputs: (grad, input, target)
    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs[0]->requires_grad();
    ctx->input_requires_grad = inputs[1]->requires_grad();
    ctx->target_requires_grad = inputs[2]->requires_grad();
    // grad is only consumed when differentiating w.r.t. input or target.
    if (ctx->input_requires_grad || ctx->target_requires_grad) {
      ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]);
    }
    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);
    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);
    ComposedAttrMap merged(attrs, base_attrs_);
    ctx->beta = JUST(merged.GetAttr<float>("beta"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(3);
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
    if (ctx->grad_requires_grad) {
      // The first-order backward is linear in grad, so its derivative w.r.t.
      // grad is the same map applied to the incoming second-order gradient.
      (*in_grads)[0] = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));
    }
    if (ctx->input_requires_grad || ctx->target_requires_grad) {
      const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
      // Mask of elements in the quadratic zone: |input - target| < beta.
      const auto diff = JUST(functional::Sub(input, target, /*alpha=*/1, /*inplace=*/false));
      const auto abs_diff = JUST(functional::Abs(diff));
      const auto quad_zone = JUST(functional::ScalarLogicalLess(abs_diff, ctx->beta));
      // Second derivative is grad / beta inside the zone and zero outside;
      // beta == 0 degenerates to a zero contribution.
      const double inv_beta = ctx->beta == 0.0 ? 0.0 : 1.0 / ctx->beta;
      const auto weighted = JUST(functional::Mul(out_grads[0], grad));
      const auto masked = JUST(functional::Mul(weighted, quad_zone));
      const auto d_input = JUST(functional::ScalarMul(inv_beta, masked));
      if (ctx->input_requires_grad) { (*in_grads)[1] = d_input; }
      // Target enters as (input - target), so its gradient is the negation.
      if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::Negative(d_input)); }
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attrs from the forward op conf
};
// Register the double-backward function for smooth_l1_loss_grad.
REGISTER_OP_EXPR_GRAD_FUNCTION("smooth_l1_loss_grad", SmoothL1LossGradGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {
// State captured for the double-backward of softmax_grad.
struct SoftmaxGradGradCaptureState : public AutoGradCaptureState {
bool y_requires_grad = false;   // forward softmax output y needs a gradient
bool dy_requires_grad = false;  // upstream gradient dy needs a gradient
};
// Double-backward for softmax_grad; member definitions follow out-of-line below.
class SoftmaxGradGrad : public OpExprGradFunction<SoftmaxGradGradCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
};
// softmax_grad has no attributes to cache, so initialization is a no-op.
Maybe<void> SoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
Maybe<void> SoftmaxGradGrad::Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  // Forward op signature: (y, dy) -> dx.
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& y = inputs.at(0);
  const auto& dy = inputs.at(1);
  ctx->y_requires_grad = y->requires_grad();
  ctx->dy_requires_grad = dy->requires_grad();
  // y is always needed (for the dy-gradient); dy is saved only when
  // differentiating w.r.t. y. Save order fixes indices 0 (y) and 1 (dy).
  ctx->SaveTensorForBackward(y);
  if (ctx->y_requires_grad) { ctx->SaveTensorForBackward(dy); }
  return Maybe<void>::Ok();
}
// Computes second-order gradients for softmax_grad.
// out_grads[0] is the gradient flowing into dx; in_grads gets the gradients
// w.r.t. y (slot 0) and dy (slot 1).
Maybe<void> SoftmaxGradGrad::Apply(const SoftmaxGradGradCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
in_grads->resize(2);
const auto& y = ctx->SavedTensors()[0];
if (ctx->y_requires_grad) {
const auto& dy = ctx->SavedTensors()[1];
// Reductions are over the last axis (the softmax axis in the forward op —
// NOTE(review): assumes forward softmax is over the last dim; confirm).
const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};
// a = dy * reduce_sum(y * out_grad, last_axis, keepdim)
const auto& a = JUST(functional::sequence_function(functional::Mul)
.then(std::bind(functional::ReduceSum, std::placeholders::_1,
reduce_axis, /*keepdim=*/true))
.then(std::bind(functional::Mul, std::placeholders::_1, dy))
.call(y, out_grads[0]));
// b = out_grad * reduce_sum(y * dy, last_axis, keepdim)
const auto& b = JUST(functional::sequence_function(functional::Mul)
.then(std::bind(functional::ReduceSum, std::placeholders::_1,
reduce_axis, /*keepdim=*/true))
.then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
.call(y, dy));
// grad_for_y = out_grad * dy - a - b
in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
.then(std::bind(functional::Sub, std::placeholders::_1, a,
/*alpha=*/1, /*inplace=*/false))
.then(std::bind(functional::Sub, std::placeholders::_1, b,
/*alpha=*/1, /*inplace=*/false))
.call(out_grads[0], dy));
}
// The first-order backward is linear in dy, so grad_for_dy reuses the same
// softmax-grad kernel applied to the incoming second-order gradient.
if (ctx->dy_requires_grad) { in_grads->at(1) = JUST(functional::SoftmaxGrad(out_grads[0], y)); }
return Maybe<void>::Ok();
}
// Register the double-backward function for softmax_grad.
REGISTER_OP_EXPR_GRAD_FUNCTION("softmax_grad", SoftmaxGradGrad);
} // namespace one
} // namespace oneflow
......@@ -39,6 +39,8 @@ Maybe<void> RawCheckAsymmetricBroadcast(Symbol<PlacedNdSbp> in, Symbol<PlacedNdS
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
CHECK_OR_RETURN(out->placement()->Bigger(*in->placement())
|| in->placement()->Bigger(*out->placement()));
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
......@@ -76,16 +78,19 @@ Maybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
static constexpr auto* CachedGetBroadcastRoot = DECORATE(&CalBroadcastRoot, ThreadLocalCached);
Maybe<one::UserOpExpr> EagerNcclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root) {
return one::OpBuilder("eager_nccl_broadcast", *JUST(UniqueStr("eager_nccl_broadcast")))
Maybe<one::UserOpExpr> EagerCclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root,
const Shape& shape) {
return one::OpBuilder("eager_ccl_broadcast", *JUST(UniqueStr("eager_ccl_broadcast")))
.Input("in")
.Output("out")
.Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
.Attr<std::vector<Shape>>("shape_list", {shape})
.Attr<int64_t>("root", root)
.Build();
}
static constexpr auto* CachedEagerNcclBroadcast = DECORATE(&EagerNcclBroadcast, ThreadLocalCached);
static constexpr auto* CachedEagerCclBroadcast =
DECORATE(&EagerCclBroadcast, ThreadLocalCachedCopiable);
} // namespace
Maybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tensor,
......@@ -105,26 +110,19 @@ Maybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tenso
if (out->placement()->Bigger(*in->placement())) {
const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_placement));
if (out_parallel_id->has_value()) {
const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_placement));
if (!in_parallel_id->has_value()) {
const std::string& device_type = in_placement->device_tag();
local_tensor =
JUST(one::functional::Empty(*tensor->shape(), tensor->dtype(),
JUST(Device::New(device_type)), /*pin_memory=*/false));
}
const auto& broadcast_group = JUST(GetBroadcastGroup(in_placement, out_placement));
Symbol<ParallelDesc> broadcast_placement_cur_rank =
JUST(MapAt(*broadcast_group, GlobalProcessCtx::Rank()));
int64_t root = JUST(CachedGetBroadcastRoot(in_placement, broadcast_placement_cur_rank));
std::shared_ptr<one::UserOpExpr> op_expr =
JUST(CachedEagerNcclBroadcast(broadcast_placement_cur_rank, root));
JUST(CachedEagerCclBroadcast(broadcast_placement_cur_rank, root, *tensor->shape()));
local_tensor = JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {local_tensor}));
}
}
return one::functional::LocalToConsistent(local_tensor, out_placement,
*JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
tensor->dtype());
return one::functional::LocalToGlobal(local_tensor, out_placement,
*JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
tensor->dtype(), /* sync_data */ false, /*copy=*/false);
}
COMMAND(RegisterBoxingFunction("asymmetric-broadcast", CheckAsymmetricBroadcast,
......
......@@ -85,17 +85,15 @@ namespace std {
template<>
struct hash<oneflow::BoxingInterpreterStatus> {
size_t operator()(const oneflow::BoxingInterpreterStatus& status) const {
using namespace oneflow;
size_t ret = 0;
for (const auto& boxing_name : *status.sorted_boxing_names()) {
ret ^= std::hash<string>()(boxing_name);
}
const auto& placed_nd_sbp_hash = std::hash<oneflow::PlacedNdSbp>();
ret ^= placed_nd_sbp_hash(*status.src_placed_nd_sbp());
for (const auto& boxing_name : *status.sorted_boxing_names()) { AddHash(&ret, boxing_name); }
AddHash(&ret, *status.src_placed_nd_sbp());
for (const auto& mid_placed_nd_sbp : *status.mid_placed_nd_sbp()) {
ret ^= placed_nd_sbp_hash(*mid_placed_nd_sbp);
AddHash(&ret, *mid_placed_nd_sbp);
}
ret ^= placed_nd_sbp_hash(*status.dst_placed_nd_sbp());
return hash<size_t>()(ret);
AddHash(&ret, *status.dst_placed_nd_sbp());
return ret;
}
};
......
......@@ -13,16 +13,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
namespace oneflow {
namespace {
class EagerBoxingKernelRegContext final : public user_op::KernelRegContext {
public:
explicit EagerBoxingKernelRegContext(DeviceType device_type) : device_type_(device_type) {}
~EagerBoxingKernelRegContext() = default;
DeviceType device_type() const override { return device_type_; }
const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const std::vector<std::pair<std::string, int32_t>>& inputs() const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const std::vector<std::pair<std::string, int32_t>>& outputs() const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
PRINT_BUG_PROMPT_AND_ABORT();
}
private:
DeviceType device_type_;
};
Maybe<bool> RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) {
EagerBoxingKernelRegContext reg_ctx(device_type);
return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx);
}
static constexpr auto* CheckCclKernelRegistered =
DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable);
Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
const Shape& logical_shape) {
// NOLINTBEGIN(maybe-need-error-msg)
......@@ -33,8 +72,9 @@ Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_ccl_all_reduce", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
......@@ -53,8 +93,9 @@ Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_ccl_reduce_scatter", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
......@@ -74,8 +115,9 @@ Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
CHECK_OR_RETURN(in->placement() == out->placement());
CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
|| in->placement()->device_type() == DeviceType::kCUDA);
CHECK_OR_RETURN( // NOLINT
JUST(CheckCclKernelRegistered("eager_ccl_all_gather", // NOLINT
in->placement()->device_type()))); // NOLINT
// NOLINTEND(maybe-need-error-msg)
return Maybe<void>::Ok();
}
......@@ -122,7 +164,7 @@ Maybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
return JUST(one::functional::ConsistentAllReduce(tensor));
return JUST(one::functional::GlobalAllReduce(tensor));
}
Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
......@@ -137,7 +179,7 @@ Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
return JUST(one::functional::ConsistentReduceScatter(tensor, "sum"));
return JUST(one::functional::GlobalReduceScatter(tensor, "sum"));
}
Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
......@@ -151,7 +193,7 @@ Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
return JUST(one::functional::ConsistentAllGather(tensor));
return JUST(one::functional::GlobalAllGather(tensor));
}
Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
......@@ -165,7 +207,7 @@ Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
return JUST(one::functional::ConsistentS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
return JUST(one::functional::GlobalS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
}
COMMAND(RegisterBoxingFunction("ccl-p-to-b", CheckCclP2B, &CclP2B));
......
......@@ -63,17 +63,11 @@ Maybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor
<< Error::RuntimeError() << "The placement of input tensor ("
<< *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
<< *JUST(PlacementToString(in->placement())) << ")";
std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));
if (!out_parallel_id->has_value()) {
const std::string& device_type = tensor_placement->device_tag();
local_tensor = JUST(one::functional::Empty(
*JUST(GetPhysicalShape(*tensor->shape(), *tensor_nd_sbp, *tensor_placement, 0)),
tensor->dtype(), JUST(Device::New(device_type)), /*pin_memory=*/false));
}
const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());
const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
*tensor->shape(), tensor->dtype()));
return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
*tensor->shape(), tensor->dtype(),
/* sync_data */ false, /*copy=*/false));
}
COMMAND(RegisterBoxingFunction("copy-h2d", &CheckCopyH2D, &CopyBoxingFunction));
......
......@@ -38,7 +38,7 @@ Maybe<one::Tensor> EagerBoxingInterpreter::Interpret(const std::shared_ptr<one::
Symbol<ParallelDesc> in_parallel_desc,
Symbol<ParallelDesc> out_parallel_desc) const {
JUST(CheckEagerBoxingDataType(input->dtype()->data_type()));
DisableCheckConsistentTensorMetaScope disable_meta_check;
DisableCheckGlobalTensorMetaScope disable_meta_check;
const auto& tensor =
JUST(InterpretImpl(input, in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc));
const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
......
......@@ -38,6 +38,13 @@ Maybe<BoxingExprIf> OptionalCudaCopy(const std::shared_ptr<BoxingExprIf>& core_b
core_boxing_expr, JUST(OptionalBoxing("copy-d2h"))))));
}
Maybe<BoxingExprIf> OptionalCpuCopy(const std::shared_ptr<BoxingExprIf>& core_boxing_expr) {
return JUST(BoxingExpr(JUST(ReplaceInDeviceType(DeviceType::kCPU)),
JUST(OptionalBoxing("copy-d2h")),
JUST(BoxingExpr(JUST(ReplaceOutDeviceType(DeviceType::kCPU)),
core_boxing_expr, JUST(OptionalBoxing("copy-h2d"))))));
}
Maybe<BoxingExprIf> SymmetricOneDimSxToBBoxingExpr() {
return JUST(BoxingExpr(JUST(InPlacementAndSplit(0)), JUST(OptionalBoxing("ccl-s-to-s")),
JUST(BoxingExpr("ccl-s-to-b"))));
......@@ -152,7 +159,7 @@ Maybe<BoxingExprIf> RawMainBoxingExpr() {
| JUST(SymmetricNDimToOneDimBoxingExpr())
| JUST(GenericBoxingExpr());
// clang-format on
return core | JUST(OptionalCudaCopy(core));
return core | JUST(OptionalCudaCopy(core)) | JUST(OptionalCpuCopy(core));
}
} // namespace
......
......@@ -69,8 +69,9 @@ Maybe<one::Tensor> FlattenHierarchy(const std::shared_ptr<one::Tensor>& tensor,
<< *JUST(PlacementToString(in->placement())) << ")";
const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor());
const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
*tensor->shape(), tensor->dtype()));
return JUST(one::functional::LocalToGlobal(local_tensor, out->placement(), *sbp_list,
*tensor->shape(), tensor->dtype(),
/* sync_data */ false, /*copy=*/true));
}
COMMAND(RegisterBoxingFunction("flatten-hierarchy", CheckFlattenHierarchy, &FlattenHierarchy));
......
......@@ -163,9 +163,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
<< Error::RuntimeError() << "Invalid input tensor, size of local tensor ("
<< local_tensor->shape()->ToString() << ") does not match global tensor ("
<< logical_shape->ToString() << ")!";
std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToConsistent(
std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToGlobal(
local_tensor, sub_parallel_desc, *JUST(GetSbpList(one_dim_nd_sbp)), sub_logical_shape,
local_tensor->dtype()));
local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));
sub_global_tensor =
JUST(Apply1DBoxing(sub_global_tensor, one_dim_nd_sbp, JUST(SbpToNdSbp(broadcast_sbp)),
......@@ -175,9 +175,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, *broadcast_sbp, i));
output = JUST(one::functional::LocalToConsistent(local_tensor, in_parallel_desc,
*JUST(GetSbpList(new_nd_sbp)),
*logical_shape, local_tensor->dtype()));
output = JUST(one::functional::LocalToGlobal(
local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape,
local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));
}
CHECK_OR_RETURN(IsAllBroadcastNdSbpAfterDim(JUST(output->nd_sbp()), first_diff_sbp_dim))
......@@ -202,9 +202,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
std::shared_ptr<one::Tensor> local_tensor = JUST(output->cur_rank_phy_tensor());
std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToConsistent(
std::shared_ptr<one::Tensor> sub_global_tensor = JUST(one::functional::LocalToGlobal(
local_tensor, sub_parallel_desc, *JUST(GetSbpList(JUST(SbpToNdSbp(broadcast_sbp)))),
*sub_logical_shape, local_tensor->dtype()));
*sub_logical_shape, local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));
const auto& one_dim_nd_sbp = JUST(SbpToNdSbp(sbp_parallel));
sub_global_tensor = JUST(Apply1DBoxing(sub_global_tensor, JUST(SbpToNdSbp(broadcast_sbp)),
......@@ -223,18 +223,18 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
const auto& new_nd_sbp = JUST(SetSbpAtAxis(*nd_sbp, sbp_parallel, i));
output = JUST(one::functional::LocalToConsistent(local_tensor, in_parallel_desc,
*JUST(GetSbpList(new_nd_sbp)),
*logical_shape, local_tensor->dtype()));
output = JUST(one::functional::LocalToGlobal(
local_tensor, in_parallel_desc, *JUST(GetSbpList(new_nd_sbp)), *logical_shape,
local_tensor->dtype(), /* sync_data */ false, /*copy=*/false));
// physical_shape of this axis is logical shape of next axis
sub_logical_shape = physical_shape;
}
} else {
one::ConsistentTensorMeta tensor_meta(input->shape(), input->dtype()->data_type(), out_nd_sbp,
out_parallel_desc);
const auto& tensor_impl = JUST(
one::EagerConsistentTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));
output = std::make_shared<one::ConsistentTensor>(tensor_impl);
one::GlobalTensorMeta tensor_meta(*input->shape(), input->dtype()->data_type(), out_nd_sbp,
out_parallel_desc);
const auto& tensor_impl =
JUST(one::EagerGlobalTensorImpl::New(SymbolOf(tensor_meta), input->requires_grad(), false));
output = std::make_shared<one::GlobalTensor>(tensor_impl);
}
return output;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment