Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
namespace auto_parallel {
// Determine whether the producer and consumer of this input must use the same SBP
bool RequireSameSbp(const OpNode* consumer, const std::string& ibn);
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
@@ -17,17 +17,25 @@ limitations under the License.
#include <memory> #include <memory>
#include <stack> #include <stack>
#include <queue> #include <queue>
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_methods.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/functional.h" #include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/global_param_grad_sync_mode.h" #include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h" #include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/common/env_var/debug_mode.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
namespace oneflow { namespace oneflow {
namespace one { namespace one {
@@ -75,20 +83,16 @@ bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_dat
Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) { Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
autograd::AutoGradMode mode(autograd_mode); autograd::AutoGradMode mode(autograd_mode);
auto current_grad = JUST(autograd_meta->current_grad()->GetAccTensor({})); auto current_grad = JUST(autograd_meta->current_grad_value());
if (!current_grad) { return Maybe<void>::Ok(); } if (!current_grad) { return Maybe<void>::Ok(); }
if (autograd_meta->acc_grad()) { if (autograd_meta->acc_grad()) {
// Should not inplace accumulate grad. For example, JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1.0,
// >>> z = x + y /*inplace=*/true));
// >>> p = x / z
// >>> p.sum().backward()
//
// As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value
// for dy if dx is shared with dz.
const auto& output = JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1,
/*inplace=*/autograd_meta->is_grad_acc_inplace()));
JUST(autograd_meta->set_acc_grad(output));
} else { } else {
// NOTE: acc_grad must not share storage with current_grad, because acc_grad is accumulated
// with an inplace operation, which would also modify current_grad and produce wrong results.
// See more details in https://github.com/Oneflow-Inc/oneflow/issues/8248
if (!LazyMode::is_enabled()) { current_grad = JUST(functional::Identity(current_grad)); }
JUST(autograd_meta->set_acc_grad(current_grad)); JUST(autograd_meta->set_acc_grad(current_grad));
} }
for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) { for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {
@@ -99,47 +103,50 @@ Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
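A minimal standalone sketch (plain C++, no OneFlow types; all names are illustrative) of the aliasing hazard that the NOTE in CopyOrAccGrad guards against: if acc_grad merely aliases current_grad, the in-place accumulation also mutates current_grad, which later readers would then see corrupted.

#include <iostream>
#include <memory>
#include <vector>

// Toy stand-in for a tensor: a shared buffer.
using Buf = std::shared_ptr<std::vector<double>>;

// In-place accumulation, analogous to functional::Add(..., /*inplace=*/true).
void AccumulateInplace(const Buf& acc, const Buf& cur) {
  for (size_t i = 0; i < acc->size(); ++i) { (*acc)[i] += (*cur)[i]; }
}

int main() {
  Buf current_grad = std::make_shared<std::vector<double>>(1, 1.0);
  Buf acc_grad = current_grad;  // BAD: aliases current_grad instead of holding its own copy
  AccumulateInplace(acc_grad, current_grad);
  std::cout << (*current_grad)[0] << "\n";  // prints 2: current_grad was silently doubled
  // Copying into a fresh buffer first avoids the corruption; the NOTE above points to
  // issue #8248 for the real-world case.
  Buf safe_acc = std::make_shared<std::vector<double>>(*current_grad);
  (void)safe_acc;
  return 0;
}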
Maybe<void> RawTorchConsistentTensor(const std::shared_ptr<one::Tensor>& tensor) { Maybe<void> RawTouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor) {
// Do nothing. // Do nothing.
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
static constexpr auto* TorchConsistentTensor = static constexpr auto* TouchGlobalTensor = DECORATE(&RawTouchGlobalTensor, CheckGlobalTensorMeta);
DECORATE(&RawTorchConsistentTensor, CheckConsistentTensorMeta);
Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) { Maybe<void> CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) {
for (const auto& tensor : tensor_tuple) { for (const auto& tensor : tensor_tuple) {
if (tensor->is_consistent()) { JUST(TorchConsistentTensor(tensor)); } if (tensor->is_global() && tensor->is_eager()) { JUST(TouchGlobalTensor(tensor)); }
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
std::string GetDebugGraphFileName(const std::string& mode, const std::string& suffix) {
return fmt::format("autograd_{}_rank{}_{}_graph.dot", mode, GlobalProcessCtx::Rank(), suffix);
}
} // namespace } // namespace
Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs, Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
const TensorTuple& out_grads, const TensorTuple& out_grads,
bool retain_graph, bool retain_graph,
bool create_graph) { bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs)); JUST(CheckGlobalTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(out_grads)); JUST(CheckGlobalTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check; DisableCheckGlobalTensorMetaScope disable_meta_check;
return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph); return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
} }
Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf( Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) { bool retain_graph, bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs)); JUST(CheckGlobalTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(inputs)); JUST(CheckGlobalTensorsMeta(inputs));
JUST(CheckConsistentTensorsMeta(out_grads)); JUST(CheckGlobalTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check; DisableCheckGlobalTensorMetaScope disable_meta_check;
return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph, return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
create_graph); create_graph);
} }
Maybe<void> FunctionNode::AccGrad4RetainGradTensor() { Maybe<void> FunctionNode::AccGrad4RetainGradTensor(bool create_graph) {
for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) { for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {
if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); } if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), create_graph)); }
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
@@ -149,17 +156,18 @@ Maybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {
auto& out = output_meta_data_[i]; auto& out = output_meta_data_[i];
if (out->is_leaf() && out->requires_grad()) { if (out->is_leaf() && out->requires_grad()) {
JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/create_graph));
// Conditionally box acc_grad back to the output's original placement and SBP
const auto& acc_grad = out->acc_grad(); const auto& acc_grad = out->acc_grad();
if (GlobalGradSyncMode::is_enabled() && acc_grad->is_consistent()) { if (!LazyMode::is_enabled() && GlobalGradSyncMode::is_enabled() && acc_grad->is_global()
&& acc_grad->is_eager()) {
auto& tensor_info = output_tensor_infos_[i]; auto& tensor_info = output_tensor_infos_[i];
const auto& placement = JUST(tensor_info.placement()); const auto& placement = JUST(tensor_info.placement());
const auto& nd_sbp = JUST(tensor_info.sbp()); const auto& nd_sbp = JUST(tensor_info.sbp());
JUST(out->set_acc_grad( JUST(out->set_acc_grad(
JUST(functional::ToConsistent(acc_grad, placement, *JUST(GetSbpList(nd_sbp)), JUST(functional::ToGlobal(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
GetNoneSbpList(), /* check_meta */ false)))); GetNoneSbpList(), /* check_meta */ false, /*copy=*/false))));
} }
} }
} }
@@ -182,22 +190,30 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
TensorTuple output_grads(output_meta_data_.size()); TensorTuple output_grads(output_meta_data_.size());
for (int i = 0; i < output_meta_data_.size(); ++i) { for (int i = 0; i < output_meta_data_.size(); ++i) {
if (output_meta_data_.at(i)->current_grad()->Empty()) { if (output_meta_data_.at(i)->current_grad()->Empty()) {
output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros()); // Only initialize out_grads for outputs that require grad
if (output_meta_data_[i]->requires_grad()) {
output_grads[i] = JUST(output_tensor_infos_[i].zeros());
}
} else { } else {
const auto& hooks = JUST(oneflow::VectorAt(output_meta_data_, i))->hooks();
JUST(oneflow::VectorAt(output_grads, i)) = JUST(oneflow::VectorAt(output_grads, i)) =
JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad()->GetAccTensor(hooks)); JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad_value());
} }
} }
JUST(backward_fn_->body(output_grads, &input_grads, create_graph)); JUST(backward_fn_->body(output_grads, &input_grads, create_graph));
for (int i = 0; i < input_meta_data_.size(); ++i) { for (int i = 0; i < input_meta_data_.size(); ++i) {
if (JUST(VectorAt(input_grads, i))) { if (JUST(VectorAt(input_grads, i))) {
CHECK_NOTNULL_OR_RETURN(input_meta_data_.at(i)) CHECK_NOTNULL_OR_RETURN(input_meta_data_[i])
<< name_ << name_
<< " calculate grad for tensor which requires_grad is False. Please submit an issue in " << " calculate grad for tensor which requires_grad is False. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as " "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible"; "possible";
JUST(input_meta_data_.at(i)->current_grad()->PushPartialTensor(input_grads.at(i))); JUST(input_meta_data_[i]->current_grad()->PushPartialTensor(JUST(VectorAt(input_grads, i))));
} else {
CHECK_OR_RETURN(!input_meta_data_[i])
<< name() << "'s input[" << i
<< "] need calculate grad but got nullptr. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible;";
} }
} }
return true; return true;
@@ -247,15 +263,64 @@ GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_
for (const auto& out_tensor : outputs) { for (const auto& out_tensor : outputs) {
FunctionNode* node = out_tensor->mut_grad_fn_node().get(); FunctionNode* node = out_tensor->mut_grad_fn_node().get();
roots_.emplace_back(node); roots_.emplace_back(node);
dependencies_.insert(std::make_pair(node, 0));
} }
} }
Maybe<void> GraphTask::WriteGraphToDotFile(const std::string& file_name) const {
auto ExecInfoToDotString = [](const ExecInfo& exec_info) -> std::string {
std::stringstream ss;
ss << "ExecInfo{\\l";
ss << "\tdependencies: " << exec_info.dependencies << "\\l";
ss << "\tneed_execute: " << exec_info.need_execute << "\\l";
if (exec_info.capture_indices) {
ss << "\tcapture_indices: [";
for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
ss << out_idx_and_capture_idx.second << ", ";
}
ss << "]\\l";
}
ss << "}\\l";
return ss.str();
};
auto log_stream = TeePersistentLogStream::Create(file_name);
std::vector<std::string> lines;
lines.emplace_back("digraph AutogradTaskGraph {");
lines.emplace_back("\tmargin=\"1.5\";");
lines.emplace_back("\tnode [shape=box];");
for (auto iter = grad_fn2exec_info_.begin(); iter != grad_fn2exec_info_.end(); ++iter) {
const FunctionNode* node = iter->first;
const ExecInfo& exec_info = iter->second;
// write label attribute
std::string node_color = "black";
if (exec_info.dependencies == 0 && exec_info.need_execute) { // start node
node_color = "red";
} else if (exec_info.need_execute && exec_info.capture_indices) { // end node
node_color = "green";
}
lines.emplace_back(fmt::format(
"\t\"{}\" [label=\"{}\\l{}\\l{}\", color={}];", static_cast<const void*>(node),
node->name(), static_cast<const void*>(node), ExecInfoToDotString(exec_info), node_color));
// write edge
for (const auto& next_fn : node->next_functions()) {
lines.emplace_back(fmt::format("\t\"{}\" -> \"{}\";", static_cast<const void*>(node),
static_cast<const void*>(next_fn.get())));
}
}
lines.emplace_back("}");
log_stream << fmt::format("{}", fmt::join(lines, "\n"));
log_stream->Flush();
return Maybe<void>::Ok();
}
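When the debug dump is enabled (see the IsInDebugMode() call further down), the file written here is an ordinary Graphviz digraph: each node's label carries its name, address, and ExecInfo, red marks roots that are immediately runnable (zero dependencies), and green marks nodes whose grads will be captured. It can be rendered with Graphviz, for example: dot -Tsvg <dumped_file>.dot -o autograd_graph.svg (the file name here is illustrative).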
// Computes the number of dependencies for each FunctionNode // Computes the number of dependencies for each FunctionNode
Maybe<void> GraphTask::ComputeDependencies() { Maybe<void> GraphTask::ComputeDependencies() {
HashSet<FunctionNode*> seen; HashSet<FunctionNode*> seen;
std::stack<FunctionNode*> stack; std::stack<FunctionNode*> stack;
for (FunctionNode* node : roots_) { stack.push(node); } for (FunctionNode* node : roots_) {
stack.push(node);
grad_fn2exec_info_[node].need_execute = true;
}
while (!stack.empty()) { while (!stack.empty()) {
FunctionNode* node = stack.top(); FunctionNode* node = stack.top();
@@ -263,7 +328,9 @@ Maybe<void> GraphTask::ComputeDependencies() {
if (/*bool has_seen=*/!seen.insert(node).second) { continue; } if (/*bool has_seen=*/!seen.insert(node).second) { continue; }
for (const auto& next_grad_fn : node->next_functions()) { for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get(); FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] += 1; ExecInfo& exec_info = grad_fn2exec_info_[next_node];
exec_info.dependencies += 1;
exec_info.need_execute = true;
if (seen.find(next_node) == seen.end()) { stack.push(next_node); } if (seen.find(next_node) == seen.end()) { stack.push(next_node); }
} }
} }
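ComputeDependencies is a depth-first pass that counts, for every FunctionNode reachable from the roots, how many producers must finish before it may run; GraphTask::Apply later consumes those counts Kahn-style, running a node when its count reaches zero and then decrementing its successors. A minimal standalone sketch of the same two-pass scheme, with plain ints standing in for FunctionNode* (names are illustrative):

#include <iostream>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy backward graph: node -> the nodes it feeds ("next functions").
using Graph = std::unordered_map<int, std::vector<int>>;

void RunTopological(const Graph& g, const std::vector<int>& roots) {
  // Pass 1: DFS from the roots, counting dependencies restricted to the
  // reachable subgraph (what GraphTask::ComputeDependencies does).
  std::unordered_map<int, int> deps;
  std::unordered_set<int> seen;
  std::stack<int> stack;
  for (int r : roots) { stack.push(r); }
  while (!stack.empty()) {
    int n = stack.top();
    stack.pop();
    if (!seen.insert(n).second) { continue; }
    auto it = g.find(n);
    if (it == g.end()) { continue; }
    for (int next : it->second) {
      deps[next] += 1;
      if (seen.find(next) == seen.end()) { stack.push(next); }
    }
  }
  // Pass 2: Kahn-style execution (what GraphTask::Apply does): run a node once
  // all of its producers have run, then release its successors.
  std::queue<int> ready;
  for (int r : roots) { if (deps[r] == 0) { ready.push(r); } }
  while (!ready.empty()) {
    int n = ready.front();
    ready.pop();
    std::cout << "apply " << n << "\n";  // node->Apply(create_graph) in the real engine
    auto it = g.find(n);
    if (it == g.end()) { continue; }
    for (int next : it->second) {
      if (--deps[next] == 0) { ready.push(next); }
    }
  }
}

int main() {
  Graph g{{0, {1, 2}}, {1, {3}}, {2, {3}}};  // diamond: 0 -> {1, 2} -> 3
  RunTopological(g, {0});                    // prints apply 0, 1, 2, 3
  return 0;
}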
@@ -288,9 +355,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
} }
}; };
for (const auto& input : inputs) { // initialize the variables that capture grads for the input tensors
CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get()); captured_grads_ = std::make_shared<TensorTuple>(inputs.size());
need_execute_.insert(input->mut_grad_fn_node().get()); for (int idx = 0; idx < inputs.size(); idx++) {
const auto& input = inputs[idx];
CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get()); // NOLINT(maybe-need-error-msg)
ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()];
exec_info.need_execute = true;
if (!exec_info.capture_indices) {
exec_info.capture_indices = std::make_unique<std::vector<std::pair<size_t, size_t>>>();
}
exec_info.capture_indices->emplace_back(std::make_pair(input->get_grad_fn_output_index(), idx));
} }
HashSet<FunctionNode*> seen; HashSet<FunctionNode*> seen;
@@ -305,18 +380,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
continue; continue;
} }
if (FunctionNode* node = frame.GetNextFunction()) { if (FunctionNode* node = frame.GetNextFunction()) {
dependencies_[node] += 1; grad_fn2exec_info_[node].dependencies += 1;
if (seen.find(node) == seen.end()) { if (seen.find(node) == seen.end()) {
stack.push(NodeFrame(node)); stack.push(NodeFrame(node));
continue; // recurse continue; // recurse
} }
} else { } else {
bool need_execute = grad_fn2exec_info_[frame.node_].need_execute |=
std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(), std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(),
[&](const std::shared_ptr<FunctionNode>& fn) { [&](const std::shared_ptr<FunctionNode>& fn) {
return need_execute_.find(fn.get()) != need_execute_.end(); return grad_fn2exec_info_[fn.get()].need_execute;
}); });
if (need_execute) { need_execute_.insert(frame.node_); }
seen.insert(frame.node_); seen.insert(frame.node_);
stack.pop(); stack.pop();
} }
@@ -327,26 +401,38 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) { Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
std::queue<FunctionNode*> queue; std::queue<FunctionNode*> queue;
for (FunctionNode* node : roots_) { for (FunctionNode* node : roots_) {
if (dependencies_[node] == 0) { queue.push(node); } if (grad_fn2exec_info_[node].dependencies == 0) { queue.push(node); }
} }
while (!queue.empty()) { while (!queue.empty()) {
FunctionNode* node = queue.front(); FunctionNode* node = queue.front();
queue.pop(); queue.pop();
if (!need_execute_.empty() && need_execute_.find(node) == need_execute_.end()) { auto& exec_info = grad_fn2exec_info_[node];
if (!exec_info.need_execute) {
node->ReleaseOutTensorArgs(); node->ReleaseOutTensorArgs();
continue; continue;
} }
BackwardPassScopeGuard backward_guard(node->scope());
if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; } if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }
if (exec_info.capture_indices) {
CHECK_NOTNULL_OR_RETURN(captured_grads_.get()) << "captured grads in GraphTask is nullptr";
for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
JUST(VectorAt(*captured_grads_, out_idx_and_capture_idx.second)) =
JUST(JUST(VectorAt(node->output_meta_data_, out_idx_and_capture_idx.first))
->current_grad_value());
}
}
if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); } if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }
JUST(node->AccGrad4RetainGradTensor()); JUST(node->AccGrad4RetainGradTensor(create_graph_));
node->ReleaseOutTensorArgs(); node->ReleaseOutTensorArgs();
if (!retain_graph_) { node->ReleaseData(); } if (!retain_graph_) { node->ReleaseData(); }
for (const auto& next_grad_fn : node->next_functions()) { for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get(); FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] -= 1; int32_t& dependencies = grad_fn2exec_info_[next_node].dependencies;
if (dependencies_[next_node] == 0) { queue.push(next_node); } dependencies -= 1;
if (dependencies == 0) { queue.push(next_node); }
} }
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
@@ -361,6 +447,10 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
} }
GraphTask graph_task(outputs, retain_graph, create_graph); GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependencies()); JUST(graph_task.ComputeDependencies());
if (IsInDebugMode()) {
JUST(
graph_task.WriteGraphToDotFile(GetDebugGraphFileName("backward", std::to_string(clock()))));
}
JUST(graph_task.Apply(/*save_grad_for_leaf=*/true)); JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
@@ -368,34 +458,23 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad( Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads, const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) { bool retain_graph, bool create_graph) {
std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
GraphTask graph_task(outputs, retain_graph, create_graph);
std::vector<bool> ori_retain_grad(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
JUST(inputs.at(i)->set_retain_grad(true));
}
for (int i = 0; i < outputs.size(); ++i) { for (int i = 0; i < outputs.size(); ++i) {
JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i))); JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
} }
GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependenciesAndPruneNode(inputs)); JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
JUST(graph_task.Apply(/*save_grad_for_leaf=*/false)); if (IsInDebugMode()) {
JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("grad", std::to_string(clock()))));
// Gets input grads and resume retain_grad
for (int i = 0; i < inputs.size(); ++i) {
input_current_grad->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
JUST(inputs.at(i)->set_acc_grad(nullptr));
JUST(inputs.at(i)->set_retain_grad(false));
}
} }
return input_current_grad; JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
return graph_task.GetCapturedGrads();
} }
Maybe<FunctionNode> GraphAutogradEngine::AddNode( Maybe<FunctionNode> GraphAutogradEngine::AddNode(
const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn, const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, TensorTuple* outputs) { const TensorTuple& inputs, TensorTuple* outputs) {
OF_PROFILER_RANGE_PUSH("AddAccumulateFunctionNode");
// First, push the function_node of each input tensor that is a leaf and requires grad
for (const std::shared_ptr<Tensor>& in_tensor : inputs) { for (const std::shared_ptr<Tensor>& in_tensor : inputs) {
if (in_tensor->is_leaf() && in_tensor->requires_grad()) { if (in_tensor->is_leaf() && in_tensor->requires_grad()) {
@@ -403,11 +482,17 @@ Maybe<FunctionNode> GraphAutogradEngine::AddNode(
} }
} }
OF_PROFILER_RANGE_POP();
OF_PROFILER_RANGE_PUSH("set_grad_fn_node");
std::shared_ptr<FunctionNode> func_node = std::shared_ptr<FunctionNode> func_node =
GraphFunctionNode::New(name, backward_fn, inputs, *outputs); GraphFunctionNode::New(name, backward_fn, inputs, *outputs);
for (const std::shared_ptr<Tensor>& out_tensor : *outputs) { for (int i = 0; i < outputs->size(); ++i) {
const std::shared_ptr<Tensor>& out_tensor = JUST(VectorAt(*outputs, i));
out_tensor->set_grad_fn_node(func_node); out_tensor->set_grad_fn_node(func_node);
out_tensor->set_grad_fn_output_index(i);
} }
if (LazyMode::is_enabled()) { func_node->set_scope(JUST(GetCurrentScope())); }
OF_PROFILER_RANGE_POP();
return func_node; return func_node;
} }
@@ -423,6 +508,10 @@ Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {
backward_fn->status = []() { return false; }; backward_fn->status = []() { return false; };
tensor->set_grad_fn_node(GraphFunctionNode::New( tensor->set_grad_fn_node(GraphFunctionNode::New(
"accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor})); "accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor}));
tensor->set_grad_fn_output_index(0);
if (LazyMode::is_enabled()) {
tensor->mut_grad_fn_node()->set_scope(JUST(GetTensorScope(tensor)));
}
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
...
@@ -17,12 +17,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_ #ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_ #define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#include <functional>
#include <list> #include <list>
#include <vector>
#include <memory> #include <memory>
#include <functional> #include <vector>
#include "oneflow/core/common/util.h"
#include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow { namespace oneflow {
@@ -45,7 +48,7 @@ class FunctionNode {
Maybe<bool> Apply(bool create_graph); Maybe<bool> Apply(bool create_graph);
Maybe<void> AccGrad4LeafTensor(bool create_graph); Maybe<void> AccGrad4LeafTensor(bool create_graph);
Maybe<void> AccGrad4RetainGradTensor(); Maybe<void> AccGrad4RetainGradTensor(bool create_graph);
void ReleaseOutTensorArgs(); void ReleaseOutTensorArgs();
// Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling // Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling
// `Apply` in second time // `Apply` in second time
@@ -56,10 +59,14 @@ class FunctionNode {
} }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
const std::shared_ptr<Scope>& scope() const { return scope_; }
void set_scope(const std::shared_ptr<Scope>& scope) { scope_ = scope; }
protected: protected:
friend class GraphTask;
explicit FunctionNode(const std::string& name, explicit FunctionNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn) const std::shared_ptr<BackwardFunction>& backward_fn)
: name_(name), backward_fn_(backward_fn) {} : name_(name), backward_fn_(backward_fn), scope_(nullptr) {}
const std::string name_; const std::string name_;
std::vector<std::shared_ptr<FunctionNode>> next_functions_; std::vector<std::shared_ptr<FunctionNode>> next_functions_;
@@ -70,6 +77,9 @@ class FunctionNode {
// Actual backward function builds in `AutogradInterpreter` to calculate one backward op // Actual backward function builds in `AutogradInterpreter` to calculate one backward op
std::shared_ptr<BackwardFunction> backward_fn_; std::shared_ptr<BackwardFunction> backward_fn_;
// The execution scope
std::shared_ptr<Scope> scope_;
}; };
class AutogradEngine { class AutogradEngine {
@@ -130,13 +140,26 @@ class GraphTask final {
Maybe<void> ComputeDependencies(); Maybe<void> ComputeDependencies();
Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs); Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
Maybe<void> Apply(bool save_grad_for_leaf); Maybe<void> Apply(bool save_grad_for_leaf);
std::shared_ptr<TensorTuple> GetCapturedGrads() const { return captured_grads_; }
Maybe<void> WriteGraphToDotFile(const std::string& file_name) const;
private: private:
class ExecInfo {
public:
ExecInfo() = default;
int32_t dependencies = 0;
bool need_execute = false;
// Used by the autograd.grad interface to record which tensor grads should be captured.
// Each pair means: <output index of this node, index in captured_grads_ where the grad is saved>
std::unique_ptr<std::vector<std::pair<size_t, size_t>>> capture_indices;
};
bool retain_graph_; bool retain_graph_;
bool create_graph_; bool create_graph_;
std::vector<FunctionNode*> roots_; std::vector<FunctionNode*> roots_;
HashMap<FunctionNode*, int> dependencies_; HashMap<FunctionNode*, ExecInfo> grad_fn2exec_info_;
HashSet<FunctionNode*> need_execute_; std::shared_ptr<TensorTuple> captured_grads_;
}; };
class GraphAutogradEngine final : public AutogradEngine { class GraphAutogradEngine final : public AutogradEngine {
...
@@ -25,9 +25,12 @@ namespace oneflow {
namespace one { namespace one {
TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) { TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); } if (tensor.is_global()) {
if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); } parallel_desc_ = CHECK_JUST(tensor.parallel_desc());
if (TRY(tensor.nd_sbp()).IsOk()) { nd_sbp_ = CHECK_JUST(tensor.nd_sbp()); } nd_sbp_ = CHECK_JUST(tensor.nd_sbp());
} else {
device_ = CHECK_JUST(tensor.device());
}
} }
Maybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) { Maybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) {
@@ -52,7 +55,7 @@ Maybe<Tensor> TensorInfo::zeros() const {
const auto& parallel_desc = JUST(parallel_desc_); const auto& parallel_desc = JUST(parallel_desc_);
const auto& nd_sbp = JUST(nd_sbp_); const auto& nd_sbp = JUST(nd_sbp_);
const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp)); const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp));
return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple); return functional::GlobalConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
} }
} }
@@ -60,18 +63,26 @@ AutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf)
: is_leaf_(is_leaf), : is_leaf_(is_leaf),
requires_grad_(requires_grad), requires_grad_(requires_grad),
retain_grad_(false), retain_grad_(false),
is_grad_acc_inplace_(false),
current_grad_(new TensorArg) {} current_grad_(new TensorArg) {}
Maybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) { Maybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) {
if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) { if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) {
acc_grad_ = JUST(static_zeros_tensor->AsMirroredTensor()); acc_grad_ = JUST(static_zeros_tensor->AsLocalTensor());
} else { } else {
acc_grad_ = grad; acc_grad_ = grad;
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
Maybe<Tensor> AutogradMeta::current_grad_value() const {
std::shared_ptr<Tensor> res = JUST(current_grad_->GetAccTensor());
for (const auto& hook : hooks_) {
const auto& new_tensor = hook(res);
if (new_tensor) { res = new_tensor; }
}
return res;
}
} // namespace one } // namespace one
} // namespace oneflow } // namespace oneflow
@@ -36,7 +36,7 @@ namespace one {
class Tensor; class Tensor;
class TensorArg; class TensorArg;
class MirroredTensor; class LocalTensor;
class AutogradMeta final { class AutogradMeta final {
public: public:
@@ -46,7 +46,8 @@ class AutogradMeta final {
// Getters // Getters
const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; } const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; } const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }
bool is_grad_acc_inplace() const { return is_grad_acc_inplace_; } // get current grad processed by hooks
Maybe<Tensor> current_grad_value() const;
bool requires_grad() const { return requires_grad_; } bool requires_grad() const { return requires_grad_; }
bool is_leaf() const { return is_leaf_; } bool is_leaf() const { return is_leaf_; }
bool retain_grad() const { return retain_grad_; } bool retain_grad() const { return retain_grad_; }
@@ -59,7 +60,6 @@ class AutogradMeta final {
// Setters // Setters
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad); Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; } std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
void set_is_grad_acc_inplace(bool is_inplace) { is_grad_acc_inplace_ = is_inplace; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; } void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; } void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
@@ -77,10 +77,6 @@ class AutogradMeta final {
// Only meaningful on non_leaf Tensors (must be false otherwise) // Only meaningful on non_leaf Tensors (must be false otherwise)
bool retain_grad_; bool retain_grad_;
// Control whether grad accumulation is inplace. Don't change it
// unless you know what you are doing
bool is_grad_acc_inplace_;
std::shared_ptr<Tensor> acc_grad_; std::shared_ptr<Tensor> acc_grad_;
std::shared_ptr<TensorArg> current_grad_; std::shared_ptr<TensorArg> current_grad_;
std::vector<Hook> hooks_; std::vector<Hook> hooks_;
@@ -104,8 +100,8 @@ class TensorInfo final {
std::shared_ptr<const Shape> shape_; std::shared_ptr<const Shape> shape_;
Symbol<DType> dtype_; Symbol<DType> dtype_;
Optional<Symbol<Device>> device_; // for local tensor Optional<Symbol<Device>> device_; // for local tensor
Optional<Symbol<ParallelDesc>> parallel_desc_; // for consistent tensor Optional<Symbol<ParallelDesc>> parallel_desc_; // for global tensor
Optional<Symbol<NdSbp>> nd_sbp_; // for consistent tensor Optional<Symbol<NdSbp>> nd_sbp_; // for global tensor
}; };
} // namespace one } // namespace one
...
@@ -108,6 +108,50 @@ class GeLU : public BaseActivation {
} }
}; };
class FastGeLU : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::FastGeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
struct QuickGeluCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(QuickGeluCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const QuickGeluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::QuickGeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
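For reference, the forward definitions these two backward functions assume, written with the constants these activations are usually defined with (the kernels' exact constants are not visible in this diff):

\mathrm{fast\_gelu}(x) = \tfrac{1}{2}\,x\Big(1 + \tanh\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\Big),
\qquad
\mathrm{quick\_gelu}(x) = x\,\sigma(1.702\,x),
\qquad
\frac{d}{dx}\big[x\,\sigma(ax)\big] = \sigma(ax) + a\,x\,\sigma(ax)\big(1 - \sigma(ax)\big),

with a = 1.702 presumably being what QuickGeluGrad evaluates element-wise against the saved input.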
class HardSigmoid : public BaseActivation { class HardSigmoid : public BaseActivation {
public: public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads, Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
@@ -558,6 +602,8 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold); REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus); REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink); REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", FastGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU);
} // namespace one } // namespace one
} // namespace oneflow } // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct AdaptiveMaxPoolCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class AdaptiveMaxPoolNdGrad : public OpExprGradFunction<AdaptiveMaxPoolCaptureState> {
public:
using OpExprGradFunction<AdaptiveMaxPoolCaptureState>::Init;
Maybe<void> Init(const OpExpr& op, const int& ndims);
Maybe<void> Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
int32_t ndims_ = 0;
};
Maybe<void> AdaptiveMaxPoolNdGrad::Init(const OpExpr& op, const int& ndims) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
ndims_ = ndims;
return Maybe<void>::Ok();
}
Maybe<void> AdaptiveMaxPoolNdGrad::Capture(AdaptiveMaxPoolCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(outputs.at(1));
return Maybe<void>::Ok();
}
Maybe<void> AdaptiveMaxPoolNdGrad::Apply(const AdaptiveMaxPoolCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(1);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::AdaptiveMaxPoolNdGrad(x, out_grads.at(0), index, ndims_));
return Maybe<void>::Ok();
}
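The backward here is the standard max-pool gradient: a scatter-add of the output gradient into the positions recorded by the saved index tensor (the second forward output), i.e.

\frac{\partial L}{\partial x_{j}} \;=\; \sum_{i\,:\,\mathrm{index}_i = j} \frac{\partial L}{\partial y_{i}},

which is what AdaptiveMaxPoolNdGrad is expected to compute from (x, dy, index, ndims).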
class AdaptiveMaxPool1dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 1); }
};
class AdaptiveMaxPool2dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 2); }
};
class AdaptiveMaxPool3dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 3); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d", AdaptiveMaxPool1dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d", AdaptiveMaxPool2dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d", AdaptiveMaxPool3dGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
enum class AmpIdentityType {
kWhite = 0,
kBlack,
};
struct AmpIdentityCaptureState : public AutoGradCaptureState {};
template<AmpIdentityType type>
class AmpIdentityGrad : public OpExprGradFunction<AmpIdentityCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Capture(AmpIdentityCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AmpIdentityCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
if (type == AmpIdentityType::kWhite) {
(*in_grads)[0] = JUST(functional::AmpWhiteIdentity(out_grads[0]));
} else if (type == AmpIdentityType::kBlack) {
(*in_grads)[0] = JUST(functional::AmpBlackIdentity(out_grads[0]));
} else {
(*in_grads)[0] = out_grads[0];
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("amp_white_identity", AmpIdentityGrad<AmpIdentityType::kWhite>);
REGISTER_OP_EXPR_GRAD_FUNCTION("amp_black_identity", AmpIdentityGrad<AmpIdentityType::kBlack>);
} // namespace one
} // namespace oneflow
@@ -22,9 +22,9 @@ namespace oneflow {
namespace one { namespace one {
struct AsStridedCaptureState : public AutoGradCaptureState { struct AsStridedCaptureState : public AutoGradCaptureState {
std::vector<int32_t> size; std::vector<int64_t> size;
std::vector<int32_t> stride; std::vector<int64_t> stride;
int32_t storage_offset = 0; int64_t storage_offset = 0;
bool requires_grad = false; bool requires_grad = false;
}; };
@@ -55,9 +55,9 @@ Maybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& in
ctx->SaveTensorForBackward(inputs.at(0)); ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_); ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("size")); ctx->size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride")); ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stride"));
ctx->storage_offset = JUST(composed_attrs.GetAttr<int32_t>("storage_offset")); ctx->storage_offset = JUST(composed_attrs.GetAttr<int64_t>("storage_offset"));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
@@ -67,9 +67,9 @@ Maybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& input = ctx->SavedTensors().at(0); const auto& input = ctx->SavedTensors().at(0);
std::vector<int32_t> size = ctx->size; std::vector<int64_t> size = ctx->size;
std::vector<int32_t> stride = ctx->stride; std::vector<int64_t> stride = ctx->stride;
int32_t storage_offset = ctx->storage_offset; int64_t storage_offset = ctx->storage_offset;
in_grads->at(0) = in_grads->at(0) =
JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset)); JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));
...
@@ -20,7 +20,9 @@ namespace oneflow {
namespace one { namespace one {
struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState { struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
bool requires_grad = false; bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
}; };
class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> { class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {
@@ -30,46 +32,42 @@ class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureSt
const TensorTuple& outputs, const AttrMap& attrs) const override; const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override; TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
}; };
Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) { Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx, Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const { const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad(); CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3); // NOLINT(maybe-need-error-msg)
if (!ctx->requires_grad) { return Maybe<void>::Ok(); } ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->target_requires_grad = inputs[1]->requires_grad();
ctx->has_weight = inputs.size() == 3;
ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->SaveTensorForBackward(inputs[0]); // input
ctx->SaveTensorForBackward(inputs.at(0)); // input ctx->SaveTensorForBackward(inputs[1]); // target
ctx->SaveTensorForBackward(inputs.at(1)); // target if (ctx->has_weight) {
if (inputs.size() == 3) { ctx->SaveTensorForBackward(inputs[2]); // weight
ctx->SaveTensorForBackward(inputs.at(2)); // weight
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx, Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const { const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0); CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
const auto& input = ctx->SavedTensors().at(0); 2 + ctx->has_weight); // NOLINT(maybe-need-error-msg)
const auto& target = ctx->SavedTensors().at(1); in_grads->resize(2 + ctx->has_weight);
in_grads->resize(ctx->SavedTensors().size());
const auto& dy = out_grads[0];
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];
const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
if (ctx->SavedTensors().size() == 3) { if (ctx->input_requires_grad) {
const auto& weight = ctx->SavedTensors().at(2); (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight)); }
} else { if (ctx->target_requires_grad) {
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt)); (*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight));
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
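The new target-gradient branch follows from differentiating the element-wise loss with respect to the target (unweighted case; the optional weight input only scales the result):

\ell(p, t) = -\big(t\,\log p + (1 - t)\,\log(1 - p)\big),
\qquad
\frac{\partial \ell}{\partial t} = \log(1 - p) - \log p,

so BinaryCrossEntropyLossTargetGrad is expected to return dy \cdot \big(\log(1-p) - \log p\big).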
...
@@ -20,7 +20,9 @@ namespace oneflow {
namespace one { namespace one {
struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState { struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
bool requires_grad = false; bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
bool has_pos_weight = false; bool has_pos_weight = false;
}; };
@@ -47,53 +49,51 @@ Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCa
const TensorTuple& inputs, const TensorTuple& inputs,
const TensorTuple& outputs, const TensorTuple& outputs,
const AttrMap& attrs) const { const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad(); CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg)
if (!ctx->requires_grad) { return Maybe<void>::Ok(); } ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->target_requires_grad = inputs[1]->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_); ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight")); ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
ctx->SaveTensorForBackward(inputs.at(0)); // input ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight);
ctx->SaveTensorForBackward(inputs.at(1)); // target ctx->SaveTensorForBackward(inputs[0]); // input
ctx->SaveTensorForBackward(inputs[1]); // target
if (inputs.size() == 3) { if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight or pos_weight ctx->SaveTensorForBackward(inputs[2]); // weight or pos_weight
} }
if (inputs.size() == 4) { if (inputs.size() == 4) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight ctx->SaveTensorForBackward(inputs[2]); // weight
ctx->SaveTensorForBackward(inputs.at(3)); // pos_weight ctx->SaveTensorForBackward(inputs[3]); // pos_weight
} }
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx, Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& out_grads, const TensorTuple& out_grads,
TensorTuple* in_grads) const { TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0); CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
const auto& input = ctx->SavedTensors().at(0); 2 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg)
const auto& target = ctx->SavedTensors().at(1); const auto& dy = out_grads[0];
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];
in_grads->resize(ctx->SavedTensors().size()); in_grads->resize(ctx->SavedTensors().size());
if (ctx->SavedTensors().size() == 3) { size_t pos_weight_index = ctx->has_weight ? 3 : 2;
if (ctx->has_pos_weight) { auto weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
const auto& pos_weight = ctx->SavedTensors().at(2); auto pos_weight =
in_grads->at(0) = JUST( ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
} else { if (ctx->input_requires_grad) {
const auto& weight = ctx->SavedTensors().at(2); (*in_grads)[0] = JUST(
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
}
} else if (ctx->SavedTensors().size() == 4) {
const auto& weight = ctx->SavedTensors().at(2);
const auto& pos_weight = ctx->SavedTensors().at(3);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight)); functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
} else {
in_grads->at(0) =
JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
} }
if (ctx->target_requires_grad) {
(*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsLossTargetGrad(
dy, input, target, weight, pos_weight));
}
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
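For the with-logits form the target gradient collapses nicely (unweighted case; weight and pos_weight modify it): with logit x and p = \sigma(x),

\ell(x, t) = -\big(t\,\log\sigma(x) + (1 - t)\,\log(1 - \sigma(x))\big),
\qquad
\frac{\partial \ell}{\partial t} = \log(1 - \sigma(x)) - \log\sigma(x) = -x,

so the unweighted target grad is simply -x \cdot dy (and, for the reduce_mean variant below, additionally divided by the element count).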
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits); REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);
...
@@ -21,8 +21,8 @@ namespace oneflow {
namespace one { namespace one {
struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState { struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
bool requires_grad = false; bool input_requires_grad = false;
bool has_pos_weight = false; bool target_requires_grad = false;
}; };
class BinaryCrossEntropyWithLogitsReduceMean class BinaryCrossEntropyWithLogitsReduceMean
@@ -34,25 +34,19 @@ class BinaryCrossEntropyWithLogitsReduceMean
const AttrMap& attrs) const override; const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override; const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
}; };
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
    BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
    const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
  ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
  ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target
  return Maybe<void>::Ok();
...@@ -61,14 +55,20 @@ Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
    const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
    TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& dy = JUST(VectorAt(out_grads, 0));
  const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
  const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
  in_grads->resize(2);

  if (ctx->input_requires_grad) {
    (*in_grads)[0] =
        JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
  }
  if (ctx->target_requires_grad) {
    (*in_grads)[1] =
        JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target));
  }
  return Maybe<void>::Ok();
}
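Editor's note: under mean reduction the per-element input gradient simply picks up a 1/N factor, assuming the mean is taken over every element. A minimal standalone check (plain C++, not OneFlow code):

// Standalone sketch: d(mean BCE-with-logits)/dx_i = (sigmoid(x_i) - t_i) / N,
// verified for one element against a central finite difference.
#include <cmath>
#include <cstdio>
#include <vector>

static double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

static double mean_bce_with_logits(const std::vector<double>& x, const std::vector<double>& t) {
  double sum = 0.0;
  for (size_t i = 0; i < x.size(); ++i) {
    sum += -t[i] * std::log(sigmoid(x[i])) - (1.0 - t[i]) * std::log(1.0 - sigmoid(x[i]));
  }
  return sum / x.size();
}

int main() {
  std::vector<double> x = {0.3, -1.2, 2.0}, t = {1.0, 0.0, 1.0};
  const double eps = 1e-6, n = x.size();
  const double analytic = (sigmoid(x[0]) - t[0]) / n;
  std::vector<double> xp = x, xm = x;
  xp[0] += eps;
  xm[0] -= eps;
  const double numeric = (mean_bce_with_logits(xp, t) - mean_bce_with_logits(xm, t)) / (2 * eps);
  std::printf("analytic=%.8f numeric=%.8f\n", analytic, numeric);
  return 0;
}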
......
...@@ -232,13 +232,12 @@ class BroadcastPow : public BroadcastBinaryGrad {
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    const auto& y = ctx->SavedTensors().at(ctx->y_index);
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      (*in_grads)[0] = JUST(functional::BroadcastPowXGrad(x, y, out_grads[0]));
    }
    if (ctx->y_requires_grad) {
      (*in_grads)[1] = JUST(functional::BroadcastPowYGrad(x, y, out_grads[0]));
    }
    return Maybe<void>::Ok();
  }
...@@ -246,9 +245,8 @@ class BroadcastPow : public BroadcastBinaryGrad {
 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);
    ctx->y_index = ctx->SaveTensorForBackward(inputs[1]);
    return Maybe<void>::Ok();
  }
};
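Editor's note: the element-wise identities behind BroadcastPowXGrad/BroadcastPowYGrad are the usual power rules, d/dx x^y = y*x^(y-1) and d/dy x^y = x^y*ln(x); broadcasting reductions are a separate concern. A minimal standalone check (plain C++, not OneFlow code):

// Standalone sketch: power-rule identities checked against central finite differences.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 1.7, y = 2.3, eps = 1e-6;
  const double dx_analytic = y * std::pow(x, y - 1.0);
  const double dy_analytic = std::pow(x, y) * std::log(x);
  const double dx_numeric = (std::pow(x + eps, y) - std::pow(x - eps, y)) / (2 * eps);
  const double dy_numeric = (std::pow(x, y + eps) - std::pow(x, y - eps)) / (2 * eps);
  std::printf("d/dx: %.8f vs %.8f\n", dx_analytic, dx_numeric);
  std::printf("d/dy: %.8f vs %.8f\n", dy_analytic, dy_numeric);
  return 0;
}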
...@@ -348,5 +346,80 @@ class BroadcastMaximum : public BroadcastMinMax {
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);
class BroadcastFMod : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& out_shape = *(JUST(VectorAt(out_grads, 0))->shape());
in_grads->resize(2);
if (ctx->x_requires_grad || ctx->y_requires_grad) {
const auto& x = JUST(VectorAt(ctx->SavedTensors(), ctx->x_index));
const auto& y = JUST(VectorAt(ctx->SavedTensors(), ctx->y_index));
auto broad_x_ = x;
auto broad_y_ = y;
if (ctx->broadcast_x) {
const auto& x_shape = *(x->shape());
const Shape& left_extended_x_shape =
CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
if (left_extended_x_shape == out_shape) {
broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> x_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
}
}
if (ctx->broadcast_y) {
const auto& y_shape = *(y->shape());
const Shape& left_extended_y_shape =
CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
if (left_extended_y_shape == out_shape) {
broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> y_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
}
}
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::BroadcastReduceSumLike(JUST(VectorAt(out_grads, 0)), x));
} else {
JUST(VectorAt(*in_grads, 0)) = JUST(VectorAt(out_grads, 0));
}
}
if (ctx->y_requires_grad) {
auto result = JUST(functional::TruncDiv(broad_x_, broad_y_));
result = JUST(functional::Mul(JUST(VectorAt(out_grads, 0)), result));
JUST(functional::ScalarMul(result, Scalar(-1.f), true));
if (ctx->broadcast_y) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(result, y));
} else {
in_grads->at(1) = result;
}
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
}
if (ctx->y_requires_grad) {
ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
ctx->y_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);
}  // namespace one
}  // namespace oneflow
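Editor's note: for z = fmod(x, y) = x - trunc(x/y) * y, the element-wise derivatives away from the discontinuities are dz/dx = 1 and dz/dy = -trunc(x/y), which is what BroadcastFMod::Apply above computes via TruncDiv, Mul and ScalarMul(-1). A minimal standalone check (plain C++, not OneFlow code):

// Standalone sketch: fmod derivatives checked against central finite differences.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 7.3, y = 2.1, eps = 1e-6;
  const double dy_analytic = -std::trunc(x / y);
  const double dx_numeric = (std::fmod(x + eps, y) - std::fmod(x - eps, y)) / (2 * eps);
  const double dy_numeric = (std::fmod(x, y + eps) - std::fmod(x, y - eps)) / (2 * eps);
  std::printf("dz/dx ~ %.8f (expected 1)\n", dx_numeric);
  std::printf("dz/dy: %.8f vs %.8f\n", dy_analytic, dy_numeric);
  return 0;
}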
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct BroadcastFModCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CastConsistentCaptureState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
std::shared_ptr<const Shape> shape;
Symbol<DType> dtype;
};
class CastToConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastFromConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
std::shared_ptr<Tensor> out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_to_consistent but got local tensor";
{
Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
out_grad = JUST(functional::ToConsistent(out_grad, parallel_desc_constraint,
*JUST(GetSbpList(nd_sbp_constraint)),
GetNoneSbpList(), /* check_meta */ false));
}
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_to_consistent", CastToConsistent);
class CastFromConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastToConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
const auto& input = inputs.at(0);
CHECK_OR_RETURN(input->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_from_consistent but got local tensor";
ctx->parallel_desc = JUST(input->parallel_desc());
ctx->nd_sbp = JUST(input->nd_sbp());
ctx->shape = input->shape();
ctx->dtype = input->dtype();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", *ctx->shape));
JUST(attrs.SetAttr<DataType>("dtype", ctx->dtype->data_type()));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(
*grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_from_consistent", CastFromConsistent);
} // namespace one
} // namespace oneflow
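Editor's note: a minimal sketch of the SBP duality that the cast_to/from_consistent backwards above are assumed to rely on via GetDualNdSbp, namely broadcast <-> partial_sum with split(axis) unchanged. This is only an illustration with a toy enum, not the real Symbol<NdSbp>-based implementation.

// Standalone sketch (plain C++): toy model of the assumed SBP dual mapping.
#include <cstdio>
#include <vector>

struct Sbp {
  enum Kind { kBroadcast, kPartialSum, kSplit } kind;
  int split_axis = -1;  // only meaningful for kSplit
};

static Sbp Dual(const Sbp& sbp) {
  switch (sbp.kind) {
    case Sbp::kBroadcast: return {Sbp::kPartialSum};
    case Sbp::kPartialSum: return {Sbp::kBroadcast};
    case Sbp::kSplit: return {Sbp::kSplit, sbp.split_axis};
  }
  return sbp;
}

int main() {
  std::vector<Sbp> nd_sbp = {{Sbp::kBroadcast}, {Sbp::kSplit, 0}};
  for (const auto& s : nd_sbp) {
    Sbp d = Dual(s);
    std::printf("kind %d(axis %d) -> dual kind %d(axis %d)\n", s.kind, s.split_axis, d.kind,
                d.split_axis);
  }
  return 0;
}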
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
namespace one {
struct ConsistentToConsistentState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
};
class ConsistentToConsistentGradFunction : public OpExprGradFunction<ConsistentToConsistentState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const ConsistentToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();
return Maybe<void>::Ok();
}
Maybe<void> Capture(ConsistentToConsistentState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());
ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ConsistentToConsistentState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for consistent_to_consistent but got local tensor";
in_grads->resize(1);
const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
(*in_grads)[0] = JUST(one::functional::ToConsistent(
out_grad, ctx->parallel_desc, *grad_sbp_list, *grad_grad_sbp_list, /* check_meta */ false));
return Maybe<void>::Ok();
}
private:
Optional<Symbol<NdSbp>> grad_nd_sbp_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("consistent_to_consistent", ConsistentToConsistentGradFunction);
} // namespace one
} // namespace oneflow
...@@ -26,6 +26,8 @@ namespace one {
struct ConvolutionNdCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;
  bool weight_requires_grad = false;
  bool has_bias = false;
  bool bias_requires_grad = false;
  size_t input_index;
  size_t weight_index;

...@@ -58,10 +60,17 @@ Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_OR_RETURN(inputs.size() == 2 || inputs.size() == 3);  // NOLINT(maybe-need-error-msg)
  ctx->input_requires_grad = inputs.at(0)->requires_grad();
  ctx->weight_requires_grad = inputs.at(1)->requires_grad();
  if (inputs.size() == 3) {
    ctx->has_bias = true;
    ctx->bias_requires_grad = inputs.at(2)->requires_grad();
  }
  if (!ctx->input_requires_grad && !ctx->weight_requires_grad && !ctx->bias_requires_grad) {
    return Maybe<void>::Ok();
  }
  if (ctx->input_requires_grad) {
    ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1));  // weight
  }

...@@ -79,7 +88,11 @@ Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorT
Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
                                 TensorTuple* in_grads) const {
  if (ctx->has_bias) {
    in_grads->resize(3);
  } else {
    in_grads->resize(2);
  }
  size_t num_spatial_dims = ctx->kernel_size.size();
  if (ctx->input_requires_grad) {
    const auto& weight = ctx->SavedTensors().at(ctx->weight_index);

...@@ -94,6 +107,18 @@ Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const Ten
        out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
if (ctx->bias_requires_grad) {
std::vector<int32_t> dim;
for (int i = 0; i < out_grads.at(0)->shape()->NumAxes(); ++i) {
if ((ctx->data_format == "channels_first" && i == 1)
|| (ctx->data_format == "channels_last"
&& i == out_grads.at(0)->shape()->NumAxes() - 1)) {
continue;
}
dim.push_back(i);
}
in_grads->at(2) = JUST(functional::ReduceSum(out_grads.at(0), dim, false));
}
  return Maybe<void>::Ok();
}
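Editor's note: the bias gradient above is a reduce-sum of out_grads[0] over every axis except the channel axis (axis 1 for "channels_first", the last axis for "channels_last"). A minimal standalone sketch of the same dim selection (plain C++, not OneFlow code); num_axes stands for out_grads.at(0)->shape()->NumAxes():

// Standalone sketch: which dims the bias gradient reduce-sums over, per data format.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

static std::vector<int32_t> BiasReduceDims(const std::string& data_format, int num_axes) {
  std::vector<int32_t> dims;
  const int channel_axis = (data_format == "channels_first") ? 1 : num_axes - 1;
  for (int i = 0; i < num_axes; ++i) {
    if (i != channel_axis) { dims.push_back(i); }
  }
  return dims;
}

int main() {
  for (int d : BiasReduceDims("channels_first", 4)) { std::printf("%d ", d); }  // prints 0 2 3
  std::printf("\n");
  for (int d : BiasReduceDims("channels_last", 4)) { std::printf("%d ", d); }  // prints 0 1 2
  std::printf("\n");
  return 0;
}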
......
...@@ -38,8 +38,14 @@ class Copy : public OpExprGradFunction<CopyCaptureState> {
  Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    if (inputs[0]->is_global()) {
      ctx->device_type = JUST(inputs[0]->parallel_desc())->device_tag();
      ctx->device_id = 0;  // global tensor only has one local device
    } else {
      ctx->device_type = JUST(inputs[0]->device())->type();
      ctx->device_id = JUST(inputs[0]->device())->device_id();
    }
    return Maybe<void>::Ok();
  }
......
...@@ -57,7 +57,7 @@ Maybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->max_target_length = JUST(composed_attrs.GetAttr<int64_t>("max_target_length"));
  ctx->blank = JUST(composed_attrs.GetAttr<int64_t>("blank"));
  ctx->zero_infinity = JUST(composed_attrs.GetAttr<bool>("zero_infinity"));
  CHECK_EQ_OR_RETURN(inputs.size(), 4);  // NOLINT(maybe-need-error-msg)
......