Commit a715222c authored by yuguo

0.9.1-rocm

parent f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
namespace auto_parallel {
// Decide whether the producer and the consumer need to use the same SBP
bool RequireSameSbp(const OpNode* consumer, const std::string& ibn);
} // namespace auto_parallel
} // namespace oneflow
#endif // ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
......@@ -17,17 +17,25 @@ limitations under the License.
#include <memory>
#include <stack>
#include <queue>
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_methods.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/common/env_var/debug_mode.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
namespace oneflow {
namespace one {
......@@ -75,20 +83,16 @@ bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_dat
Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
autograd::AutoGradMode mode(autograd_mode);
auto current_grad = JUST(autograd_meta->current_grad()->GetAccTensor({}));
auto current_grad = JUST(autograd_meta->current_grad_value());
if (!current_grad) { return Maybe<void>::Ok(); }
if (autograd_meta->acc_grad()) {
// The grad should not be accumulated in place. For example,
// >>> z = x + y
// >>> p = x / z
// >>> p.sum().backward()
//
// Since dx = dz + dp / z and dy = dz, dy will get a wrong value
// if dx shares storage with dz.
const auto& output = JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1,
/*inplace=*/autograd_meta->is_grad_acc_inplace()));
JUST(autograd_meta->set_acc_grad(output));
JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1.0,
/*inplace=*/true));
} else {
// NOTE: acc_grad must not share data with current_grad, because acc_grad is accumulated
// in place, which could modify current_grad and produce a wrong result.
// See more details in https://github.com/Oneflow-Inc/oneflow/issues/8248
if (!LazyMode::is_enabled()) { current_grad = JUST(functional::Identity(current_grad)); }
JUST(autograd_meta->set_acc_grad(current_grad));
}
for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {
......@@ -99,47 +103,50 @@ Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
return Maybe<void>::Ok();
}
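// Illustrative standalone sketch (not part of this commit, not OneFlow API): why a gradient
// must not be accumulated in place when another gradient aliases the same storage.
//   auto dz = std::make_shared<std::vector<double>>(2, 1.0);
//   auto dy = dz;                          // dy aliases dz's storage (dy == dz in the example above)
//   for (double& v : *dz) { v += 0.5; }    // accumulate another partial grad into dz in place
//   // dy now reads {1.5, 1.5}, although the correct dy is still {1.0, 1.0}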
Maybe<void> RawTorchConsistentTensor(const std::shared_ptr<one::Tensor>& tensor) {
Maybe<void> RawTouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor) {
// Do nothing.
return Maybe<void>::Ok();
}
static constexpr auto* TorchConsistentTensor =
DECORATE(&RawTorchConsistentTensor, CheckConsistentTensorMeta);
static constexpr auto* TouchGlobalTensor = DECORATE(&RawTouchGlobalTensor, CheckGlobalTensorMeta);
Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) {
Maybe<void> CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) {
for (const auto& tensor : tensor_tuple) {
if (tensor->is_consistent()) { JUST(TorchConsistentTensor(tensor)); }
if (tensor->is_global() && tensor->is_eager()) { JUST(TouchGlobalTensor(tensor)); }
}
return Maybe<void>::Ok();
}
std::string GetDebugGraphFileName(const std::string& mode, const std::string& suffix) {
return fmt::format("autograd_{}_rank{}_{}_graph.dot", mode, GlobalProcessCtx::Rank(), suffix);
}
} // namespace
Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
const TensorTuple& out_grads,
bool retain_graph,
bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check;
JUST(CheckGlobalTensorsMeta(outputs));
JUST(CheckGlobalTensorsMeta(out_grads));
DisableCheckGlobalTensorMetaScope disable_meta_check;
return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
}
Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) {
JUST(CheckConsistentTensorsMeta(outputs));
JUST(CheckConsistentTensorsMeta(inputs));
JUST(CheckConsistentTensorsMeta(out_grads));
DisableCheckConsistentTensorMetaScope disable_meta_check;
JUST(CheckGlobalTensorsMeta(outputs));
JUST(CheckGlobalTensorsMeta(inputs));
JUST(CheckGlobalTensorsMeta(out_grads));
DisableCheckGlobalTensorMetaScope disable_meta_check;
return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
create_graph);
}
Maybe<void> FunctionNode::AccGrad4RetainGradTensor() {
Maybe<void> FunctionNode::AccGrad4RetainGradTensor(bool create_graph) {
for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {
if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); }
if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), create_graph)); }
}
return Maybe<void>::Ok();
}
......@@ -149,17 +156,18 @@ Maybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {
auto& out = output_meta_data_[i];
if (out->is_leaf() && out->requires_grad()) {
JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false));
JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/create_graph));
// Conditionally do boxing on acc_grad
const auto& acc_grad = out->acc_grad();
if (GlobalGradSyncMode::is_enabled() && acc_grad->is_consistent()) {
if (!LazyMode::is_enabled() && GlobalGradSyncMode::is_enabled() && acc_grad->is_global()
&& acc_grad->is_eager()) {
auto& tensor_info = output_tensor_infos_[i];
const auto& placement = JUST(tensor_info.placement());
const auto& nd_sbp = JUST(tensor_info.sbp());
JUST(out->set_acc_grad(
JUST(functional::ToConsistent(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
GetNoneSbpList(), /* check_meta */ false))));
JUST(functional::ToGlobal(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
GetNoneSbpList(), /* check_meta */ false, /*copy=*/false))));
}
}
}
......@@ -182,22 +190,30 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
TensorTuple output_grads(output_meta_data_.size());
for (int i = 0; i < output_meta_data_.size(); ++i) {
if (output_meta_data_.at(i)->current_grad()->Empty()) {
output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros());
// Only initialize out_grads for outputs that require grad
if (output_meta_data_[i]->requires_grad()) {
output_grads[i] = JUST(output_tensor_infos_[i].zeros());
}
} else {
const auto& hooks = JUST(oneflow::VectorAt(output_meta_data_, i))->hooks();
JUST(oneflow::VectorAt(output_grads, i)) =
JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad()->GetAccTensor(hooks));
JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad_value());
}
}
JUST(backward_fn_->body(output_grads, &input_grads, create_graph));
for (int i = 0; i < input_meta_data_.size(); ++i) {
if (JUST(VectorAt(input_grads, i))) {
CHECK_NOTNULL_OR_RETURN(input_meta_data_.at(i))
CHECK_NOTNULL_OR_RETURN(input_meta_data_[i])
<< name_
<< " calculate grad for tensor which requires_grad is False. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible";
JUST(input_meta_data_.at(i)->current_grad()->PushPartialTensor(input_grads.at(i)));
JUST(input_meta_data_[i]->current_grad()->PushPartialTensor(JUST(VectorAt(input_grads, i))));
} else {
CHECK_OR_RETURN(!input_meta_data_[i])
<< name() << "'s input[" << i
<< "] need calculate grad but got nullptr. Please submit an issue in "
"`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
"possible;";
}
}
return true;
......@@ -247,15 +263,64 @@ GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_
for (const auto& out_tensor : outputs) {
FunctionNode* node = out_tensor->mut_grad_fn_node().get();
roots_.emplace_back(node);
dependencies_.insert(std::make_pair(node, 0));
}
}
Maybe<void> GraphTask::WriteGraphToDotFile(const std::string& file_name) const {
auto ExecInfoToDotString = [](const ExecInfo& exec_info) -> std::string {
std::stringstream ss;
ss << "ExecInfo{\\l";
ss << "\tdependencies: " << exec_info.dependencies << "\\l";
ss << "\tneed_execute: " << exec_info.need_execute << "\\l";
if (exec_info.capture_indices) {
ss << "\tcapture_indices: [";
for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
ss << out_idx_and_capture_idx.second << ", ";
}
ss << "]\\l";
}
ss << "}\\l";
return ss.str();
};
auto log_stream = TeePersistentLogStream::Create(file_name);
std::vector<std::string> lines;
lines.emplace_back("digraph AutogradTaskGraph {");
lines.emplace_back("\tmargin=\"1.5\";");
lines.emplace_back("\tnode [shape=box];");
for (auto iter = grad_fn2exec_info_.begin(); iter != grad_fn2exec_info_.end(); ++iter) {
const FunctionNode* node = iter->first;
const ExecInfo& exec_info = iter->second;
// write label attribute
std::string node_color = "black";
if (exec_info.dependencies == 0 && exec_info.need_execute) { // start node
node_color = "red";
} else if (exec_info.need_execute && exec_info.capture_indices) { // end node
node_color = "green";
}
lines.emplace_back(fmt::format(
"\t\"{}\" [label=\"{}\\l{}\\l{}\", color={}];", static_cast<const void*>(node),
node->name(), static_cast<const void*>(node), ExecInfoToDotString(exec_info), node_color));
// write edge
for (const auto& next_fn : node->next_functions()) {
lines.emplace_back(fmt::format("\t\"{}\" -> \"{}\";", static_cast<const void*>(node),
static_cast<const void*>(next_fn.get())));
}
}
lines.emplace_back("}");
log_stream << fmt::format("{}", fmt::join(lines, "\n"));
log_stream->Flush();
return Maybe<void>::Ok();
}
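// Usage note (illustrative, not part of this commit): the file written above is plain Graphviz
// DOT, so it can be rendered with the standard Graphviz CLI, for example
//   dot -Tpng <dumped_file>.dot -o autograd_graph.png
// Red nodes are roots whose dependency count is zero; green nodes are the ones carrying
// capture_indices.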
// Computes the number of dependencies for each FunctionNode
Maybe<void> GraphTask::ComputeDependencies() {
HashSet<FunctionNode*> seen;
std::stack<FunctionNode*> stack;
for (FunctionNode* node : roots_) { stack.push(node); }
for (FunctionNode* node : roots_) {
stack.push(node);
grad_fn2exec_info_[node].need_execute = true;
}
while (!stack.empty()) {
FunctionNode* node = stack.top();
......@@ -263,7 +328,9 @@ Maybe<void> GraphTask::ComputeDependencies() {
if (/*bool has_seen=*/!seen.insert(node).second) { continue; }
for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] += 1;
ExecInfo& exec_info = grad_fn2exec_info_[next_node];
exec_info.dependencies += 1;
exec_info.need_execute = true;
if (seen.find(next_node) == seen.end()) { stack.push(next_node); }
}
}
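// Worked example (illustrative only): for a diamond-shaped backward graph root -> {a, b} -> c,
// the traversal above increments the dependency counts of a and b to 1 while visiting root, and
// increments the count of c once from a and once from b, so c ends with dependencies == 2.
// GraphTask::Apply later enqueues a node only when its dependency count has dropped to zero,
// i.e. after every node that feeds grads into it has already run.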
......@@ -288,9 +355,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
}
};
for (const auto& input : inputs) {
CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());
need_execute_.insert(input->mut_grad_fn_node().get());
// Initialize the buffer used to capture grads for the input tensors
captured_grads_ = std::make_shared<TensorTuple>(inputs.size());
for (int idx = 0; idx < inputs.size(); idx++) {
const auto& input = inputs[idx];
CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get()); // NOLINT(maybe-need-error-msg)
ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()];
exec_info.need_execute = true;
if (!exec_info.capture_indices) {
exec_info.capture_indices = std::make_unique<std::vector<std::pair<size_t, size_t>>>();
}
exec_info.capture_indices->emplace_back(std::make_pair(input->get_grad_fn_output_index(), idx));
}
HashSet<FunctionNode*> seen;
......@@ -305,18 +380,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
continue;
}
if (FunctionNode* node = frame.GetNextFunction()) {
dependencies_[node] += 1;
grad_fn2exec_info_[node].dependencies += 1;
if (seen.find(node) == seen.end()) {
stack.push(NodeFrame(node));
continue; // recurse
}
} else {
bool need_execute =
grad_fn2exec_info_[frame.node_].need_execute |=
std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(),
[&](const std::shared_ptr<FunctionNode>& fn) {
return need_execute_.find(fn.get()) != need_execute_.end();
return grad_fn2exec_info_[fn.get()].need_execute;
});
if (need_execute) { need_execute_.insert(frame.node_); }
seen.insert(frame.node_);
stack.pop();
}
......@@ -327,26 +401,38 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
std::queue<FunctionNode*> queue;
for (FunctionNode* node : roots_) {
if (dependencies_[node] == 0) { queue.push(node); }
if (grad_fn2exec_info_[node].dependencies == 0) { queue.push(node); }
}
while (!queue.empty()) {
FunctionNode* node = queue.front();
queue.pop();
if (!need_execute_.empty() && need_execute_.find(node) == need_execute_.end()) {
auto& exec_info = grad_fn2exec_info_[node];
if (!exec_info.need_execute) {
node->ReleaseOutTensorArgs();
continue;
}
BackwardPassScopeGuard backward_guard(node->scope());
if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }
if (exec_info.capture_indices) {
CHECK_NOTNULL_OR_RETURN(captured_grads_.get()) << "captured grads in GraphTask is nullptr";
for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
JUST(VectorAt(*captured_grads_, out_idx_and_capture_idx.second)) =
JUST(JUST(VectorAt(node->output_meta_data_, out_idx_and_capture_idx.first))
->current_grad_value());
}
}
if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }
JUST(node->AccGrad4RetainGradTensor());
JUST(node->AccGrad4RetainGradTensor(create_graph_));
node->ReleaseOutTensorArgs();
if (!retain_graph_) { node->ReleaseData(); }
for (const auto& next_grad_fn : node->next_functions()) {
FunctionNode* next_node = next_grad_fn.get();
dependencies_[next_node] -= 1;
if (dependencies_[next_node] == 0) { queue.push(next_node); }
int32_t& dependencies = grad_fn2exec_info_[next_node].dependencies;
dependencies -= 1;
if (dependencies == 0) { queue.push(next_node); }
}
}
return Maybe<void>::Ok();
......@@ -361,6 +447,10 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
}
GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependencies());
if (IsInDebugMode()) {
JUST(
graph_task.WriteGraphToDotFile(GetDebugGraphFileName("backward", std::to_string(clock()))));
}
JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));
return Maybe<void>::Ok();
}
......@@ -368,34 +458,23 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
bool retain_graph, bool create_graph) {
std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
GraphTask graph_task(outputs, retain_graph, create_graph);
std::vector<bool> ori_retain_grad(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
JUST(inputs.at(i)->set_retain_grad(true));
}
for (int i = 0; i < outputs.size(); ++i) {
JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
}
GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
// Gets input grads and resume retain_grad
for (int i = 0; i < inputs.size(); ++i) {
input_current_grad->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
JUST(inputs.at(i)->set_acc_grad(nullptr));
JUST(inputs.at(i)->set_retain_grad(false));
}
if (IsInDebugMode()) {
JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("grad", std::to_string(clock()))));
}
return input_current_grad;
JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
return graph_task.GetCapturedGrads();
}
Maybe<FunctionNode> GraphAutogradEngine::AddNode(
const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
const TensorTuple& inputs, TensorTuple* outputs) {
OF_PROFILER_RANGE_PUSH("AddAccumulateFunctionNode");
// First, push the function_node of each input tensor that is a leaf and requires grad
for (const std::shared_ptr<Tensor>& in_tensor : inputs) {
if (in_tensor->is_leaf() && in_tensor->requires_grad()) {
......@@ -403,11 +482,17 @@ Maybe<FunctionNode> GraphAutogradEngine::AddNode(
}
}
OF_PROFILER_RANGE_POP();
OF_PROFILER_RANGE_PUSH("set_grad_fn_node");
std::shared_ptr<FunctionNode> func_node =
GraphFunctionNode::New(name, backward_fn, inputs, *outputs);
for (const std::shared_ptr<Tensor>& out_tensor : *outputs) {
for (int i = 0; i < outputs->size(); ++i) {
const std::shared_ptr<Tensor>& out_tensor = JUST(VectorAt(*outputs, i));
out_tensor->set_grad_fn_node(func_node);
out_tensor->set_grad_fn_output_index(i);
}
if (LazyMode::is_enabled()) { func_node->set_scope(JUST(GetCurrentScope())); }
OF_PROFILER_RANGE_POP();
return func_node;
}
......@@ -423,6 +508,10 @@ Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {
backward_fn->status = []() { return false; };
tensor->set_grad_fn_node(GraphFunctionNode::New(
"accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor}));
tensor->set_grad_fn_output_index(0);
if (LazyMode::is_enabled()) {
tensor->mut_grad_fn_node()->set_scope(JUST(GetTensorScope(tensor)));
}
return Maybe<void>::Ok();
}
......
......@@ -17,12 +17,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#include <functional>
#include <list>
#include <vector>
#include <memory>
#include <functional>
#include "oneflow/core/common/util.h"
#include <vector>
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
......@@ -45,7 +48,7 @@ class FunctionNode {
Maybe<bool> Apply(bool create_graph);
Maybe<void> AccGrad4LeafTensor(bool create_graph);
Maybe<void> AccGrad4RetainGradTensor();
Maybe<void> AccGrad4RetainGradTensor(bool create_graph);
void ReleaseOutTensorArgs();
// Releases the underlying c++ std::function for backward if retain_graph=False to avoid calling
// `Apply` a second time
......@@ -56,10 +59,14 @@ class FunctionNode {
}
const std::string& name() const { return name_; }
const std::shared_ptr<Scope>& scope() const { return scope_; }
void set_scope(const std::shared_ptr<Scope>& scope) { scope_ = scope; }
protected:
friend class GraphTask;
explicit FunctionNode(const std::string& name,
const std::shared_ptr<BackwardFunction>& backward_fn)
: name_(name), backward_fn_(backward_fn) {}
: name_(name), backward_fn_(backward_fn), scope_(nullptr) {}
const std::string name_;
std::vector<std::shared_ptr<FunctionNode>> next_functions_;
......@@ -70,6 +77,9 @@ class FunctionNode {
// Actual backward function built in `AutogradInterpreter` to calculate one backward op
std::shared_ptr<BackwardFunction> backward_fn_;
// The execution scope
std::shared_ptr<Scope> scope_;
};
class AutogradEngine {
......@@ -130,13 +140,26 @@ class GraphTask final {
Maybe<void> ComputeDependencies();
Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
Maybe<void> Apply(bool save_grad_for_leaf);
std::shared_ptr<TensorTuple> GetCapturedGrads() const { return captured_grads_; }
Maybe<void> WriteGraphToDotFile(const std::string& file_name) const;
private:
class ExecInfo {
public:
ExecInfo() = default;
int32_t dependencies = 0;
bool need_execute = false;
// Used by the autograd.grad interface to record which tensor grads will be captured.
// Each pair means: <output index of this node, index in captured_grads_ where the grad is saved>
std::unique_ptr<std::vector<std::pair<size_t, size_t>>> capture_indices;
};
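// Worked example (illustrative only): if the tensor passed as inputs[3] to autograd.grad is
// the 1st output of some FunctionNode N, then N's ExecInfo records the pair {1, 3}; after N
// runs, the grad of that output is copied into captured_grads_[3] and later returned by
// GetCapturedGrads().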
bool retain_graph_;
bool create_graph_;
std::vector<FunctionNode*> roots_;
HashMap<FunctionNode*, int> dependencies_;
HashSet<FunctionNode*> need_execute_;
HashMap<FunctionNode*, ExecInfo> grad_fn2exec_info_;
std::shared_ptr<TensorTuple> captured_grads_;
};
class GraphAutogradEngine final : public AutogradEngine {
......
......@@ -25,9 +25,12 @@ namespace oneflow {
namespace one {
TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
if (TRY(tensor.nd_sbp()).IsOk()) { nd_sbp_ = CHECK_JUST(tensor.nd_sbp()); }
if (tensor.is_global()) {
parallel_desc_ = CHECK_JUST(tensor.parallel_desc());
nd_sbp_ = CHECK_JUST(tensor.nd_sbp());
} else {
device_ = CHECK_JUST(tensor.device());
}
}
Maybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) {
......@@ -52,7 +55,7 @@ Maybe<Tensor> TensorInfo::zeros() const {
const auto& parallel_desc = JUST(parallel_desc_);
const auto& nd_sbp = JUST(nd_sbp_);
const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp));
return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
return functional::GlobalConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
}
}
......@@ -60,18 +63,26 @@ AutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf)
: is_leaf_(is_leaf),
requires_grad_(requires_grad),
retain_grad_(false),
is_grad_acc_inplace_(false),
current_grad_(new TensorArg) {}
Maybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) {
if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) {
acc_grad_ = JUST(static_zeros_tensor->AsMirroredTensor());
acc_grad_ = JUST(static_zeros_tensor->AsLocalTensor());
} else {
acc_grad_ = grad;
}
return Maybe<void>::Ok();
}
Maybe<Tensor> AutogradMeta::current_grad_value() const {
std::shared_ptr<Tensor> res = JUST(current_grad_->GetAccTensor());
for (const auto& hook : hooks_) {
const auto& new_tensor = hook(res);
if (new_tensor) { res = new_tensor; }
}
return res;
}
} // namespace one
} // namespace oneflow
......@@ -36,7 +36,7 @@ namespace one {
class Tensor;
class TensorArg;
class MirroredTensor;
class LocalTensor;
class AutogradMeta final {
public:
......@@ -46,7 +46,8 @@ class AutogradMeta final {
// Getters
const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }
bool is_grad_acc_inplace() const { return is_grad_acc_inplace_; }
// Gets the current grad after applying the registered hooks
Maybe<Tensor> current_grad_value() const;
bool requires_grad() const { return requires_grad_; }
bool is_leaf() const { return is_leaf_; }
bool retain_grad() const { return retain_grad_; }
......@@ -59,7 +60,6 @@ class AutogradMeta final {
// Setters
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
void set_is_grad_acc_inplace(bool is_inplace) { is_grad_acc_inplace_ = is_inplace; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
......@@ -77,10 +77,6 @@ class AutogradMeta final {
// Only meaningful on non_leaf Tensors (must be false otherwise)
bool retain_grad_;
// Control whether grad accumulation is inplace. Don't change it
// unless you know what you are doing
bool is_grad_acc_inplace_;
std::shared_ptr<Tensor> acc_grad_;
std::shared_ptr<TensorArg> current_grad_;
std::vector<Hook> hooks_;
......@@ -104,8 +100,8 @@ class TensorInfo final {
std::shared_ptr<const Shape> shape_;
Symbol<DType> dtype_;
Optional<Symbol<Device>> device_; // for local tensor
Optional<Symbol<ParallelDesc>> parallel_desc_; // for consistent tensor
Optional<Symbol<NdSbp>> nd_sbp_; // for consistent tensor
Optional<Symbol<ParallelDesc>> parallel_desc_; // for global tensor
Optional<Symbol<NdSbp>> nd_sbp_; // for global tensor
};
} // namespace one
......
......@@ -108,6 +108,50 @@ class GeLU : public BaseActivation {
}
};
class FastGeLU : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::FastGeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
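// Reference formulas (a sketch, assuming the usual tanh approximation of GELU implemented by
// fast_gelu; the actual math lives in the FastGeluGrad kernel):
//   fast_gelu(x) = 0.5 * x * (1 + tanh(u)),  where u = sqrt(2/pi) * (x + 0.044715 * x^3)
//   d fast_gelu / dx = 0.5 * (1 + tanh(u))
//                      + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)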
struct QuickGeluCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(QuickGeluCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const QuickGeluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::QuickGeluGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};
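// Reference formulas (a sketch, assuming the common definition of quick_gelu; the actual math
// lives in the QuickGeluGrad kernel):
//   quick_gelu(x) = x * sigmoid(1.702 * x)
//   d quick_gelu / dx = sigmoid(1.702 * x)
//                       + 1.702 * x * sigmoid(1.702 * x) * (1 - sigmoid(1.702 * x))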
class HardSigmoid : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
......@@ -558,6 +602,8 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", FastGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct AdaptiveMaxPoolCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};
class AdaptiveMaxPoolNdGrad : public OpExprGradFunction<AdaptiveMaxPoolCaptureState> {
public:
using OpExprGradFunction<AdaptiveMaxPoolCaptureState>::Init;
Maybe<void> Init(const OpExpr& op, const int& ndims);
Maybe<void> Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
int32_t ndims_ = 0;
};
Maybe<void> AdaptiveMaxPoolNdGrad::Init(const OpExpr& op, const int& ndims) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
ndims_ = ndims;
return Maybe<void>::Ok();
}
Maybe<void> AdaptiveMaxPoolNdGrad::Capture(AdaptiveMaxPoolCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(outputs.at(1));
return Maybe<void>::Ok();
}
Maybe<void> AdaptiveMaxPoolNdGrad::Apply(const AdaptiveMaxPoolCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(1);
in_grads->resize(1);
in_grads->at(0) = JUST(functional::AdaptiveMaxPoolNdGrad(x, out_grads.at(0), index, ndims_));
return Maybe<void>::Ok();
}
class AdaptiveMaxPool1dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 1); }
};
class AdaptiveMaxPool2dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 2); }
};
class AdaptiveMaxPool3dGrad final : public AdaptiveMaxPoolNdGrad {
public:
Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 3); }
};
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d", AdaptiveMaxPool1dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d", AdaptiveMaxPool2dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d", AdaptiveMaxPool3dGrad);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
enum class AmpIdentityType {
kWhite = 0,
kBlack,
};
struct AmpIdentityCaptureState : public AutoGradCaptureState {};
template<AmpIdentityType type>
class AmpIdentityGrad : public OpExprGradFunction<AmpIdentityCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
return Maybe<void>::Ok();
}
Maybe<void> Capture(AmpIdentityCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AmpIdentityCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
if (type == AmpIdentityType::kWhite) {
(*in_grads)[0] = JUST(functional::AmpWhiteIdentity(out_grads[0]));
} else if (type == AmpIdentityType::kBlack) {
(*in_grads)[0] = JUST(functional::AmpBlackIdentity(out_grads[0]));
} else {
(*in_grads)[0] = out_grads[0];
}
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("amp_white_identity", AmpIdentityGrad<AmpIdentityType::kWhite>);
REGISTER_OP_EXPR_GRAD_FUNCTION("amp_black_identity", AmpIdentityGrad<AmpIdentityType::kBlack>);
} // namespace one
} // namespace oneflow
......@@ -22,9 +22,9 @@ namespace oneflow {
namespace one {
struct AsStridedCaptureState : public AutoGradCaptureState {
std::vector<int32_t> size;
std::vector<int32_t> stride;
int32_t storage_offset = 0;
std::vector<int64_t> size;
std::vector<int64_t> stride;
int64_t storage_offset = 0;
bool requires_grad = false;
};
......@@ -55,9 +55,9 @@ Maybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& in
ctx->SaveTensorForBackward(inputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
ctx->storage_offset = JUST(composed_attrs.GetAttr<int32_t>("storage_offset"));
ctx->size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("size"));
ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stride"));
ctx->storage_offset = JUST(composed_attrs.GetAttr<int64_t>("storage_offset"));
return Maybe<void>::Ok();
}
......@@ -67,9 +67,9 @@ Maybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& input = ctx->SavedTensors().at(0);
std::vector<int32_t> size = ctx->size;
std::vector<int32_t> stride = ctx->stride;
int32_t storage_offset = ctx->storage_offset;
std::vector<int64_t> size = ctx->size;
std::vector<int64_t> stride = ctx->stride;
int64_t storage_offset = ctx->storage_offset;
in_grads->at(0) =
JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));
......
......@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {
struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
};
class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {
......@@ -30,46 +32,42 @@ class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureSt
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->target_requires_grad = inputs[1]->requires_grad();
ctx->has_weight = inputs.size() == 3;
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight
ctx->SaveTensorForBackward(inputs[0]); // input
ctx->SaveTensorForBackward(inputs[1]); // target
if (ctx->has_weight) {
ctx->SaveTensorForBackward(inputs[2]); // weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
in_grads->resize(ctx->SavedTensors().size());
CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
2 + ctx->has_weight); // NOLINT(maybe-need-error-msg)
in_grads->resize(2 + ctx->has_weight);
const auto& dy = out_grads[0];
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];
const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
if (ctx->SavedTensors().size() == 3) {
const auto& weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
} else {
in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt));
if (ctx->input_requires_grad) {
(*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
}
if (ctx->target_requires_grad) {
(*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight));
}
return Maybe<void>::Ok();
}
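// Reference formulas (a sketch; the actual math lives in the grad functionals): with the
// per-element loss L = -w * (t * log(x) + (1 - t) * log(1 - x)),
//   dL/dx = w * (x - t) / (x * (1 - x))   -> BinaryCrossEntropyLossGrad
//   dL/dt = w * (log(1 - x) - log(x))     -> BinaryCrossEntropyLossTargetGrad
// both scaled by the incoming dy.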
......
......@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
bool has_weight = false;
bool has_pos_weight = false;
};
......@@ -47,53 +49,51 @@ Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCa
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs[0]->requires_grad();
ctx->target_requires_grad = inputs[1]->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight);
ctx->SaveTensorForBackward(inputs[0]); // input
ctx->SaveTensorForBackward(inputs[1]); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight or pos_weight
ctx->SaveTensorForBackward(inputs[2]); // weight or pos_weight
}
if (inputs.size() == 4) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight
ctx->SaveTensorForBackward(inputs.at(3)); // pos_weight
ctx->SaveTensorForBackward(inputs[2]); // weight
ctx->SaveTensorForBackward(inputs[3]); // pos_weight
}
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads.at(0);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
2 + ctx->has_weight + ctx->has_pos_weight); // NOLINT(maybe-need-error-msg)
const auto& dy = out_grads[0];
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];
in_grads->resize(ctx->SavedTensors().size());
if (ctx->SavedTensors().size() == 3) {
if (ctx->has_pos_weight) {
const auto& pos_weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
} else {
const auto& weight = ctx->SavedTensors().at(2);
in_grads->at(0) = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
}
} else if (ctx->SavedTensors().size() == 4) {
const auto& weight = ctx->SavedTensors().at(2);
const auto& pos_weight = ctx->SavedTensors().at(3);
in_grads->at(0) = JUST(
size_t pos_weight_index = ctx->has_weight ? 3 : 2;
auto weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
auto pos_weight =
ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;
if (ctx->input_requires_grad) {
(*in_grads)[0] = JUST(
functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
} else {
in_grads->at(0) =
JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
}
if (ctx->target_requires_grad) {
(*in_grads)[1] = JUST(functional::BinaryCrossEntropyWithLogitsLossTargetGrad(
dy, input, target, weight, pos_weight));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);
......
......@@ -21,8 +21,8 @@ namespace oneflow {
namespace one {
struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_pos_weight = false;
bool input_requires_grad = false;
bool target_requires_grad = false;
};
class BinaryCrossEntropyWithLogitsReduceMean
......@@ -34,25 +34,19 @@ class BinaryCrossEntropyWithLogitsReduceMean
const AttrMap& attrs) const override;
Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0))); // input
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target
return Maybe<void>::Ok();
......@@ -61,14 +55,20 @@ Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads size should be equal to 1. ";
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& dy = JUST(VectorAt(out_grads, 0));
const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
in_grads->resize(ctx->SavedTensors().size());
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
in_grads->resize(2);
if (ctx->input_requires_grad) {
(*in_grads)[0] =
JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
}
if (ctx->target_requires_grad) {
(*in_grads)[1] =
JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target));
}
return Maybe<void>::Ok();
}
......
......@@ -232,13 +232,12 @@ class BroadcastPow : public BroadcastBinaryGrad {
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& z = ctx->SavedTensors().at(ctx->z_index);
in_grads->resize(2);
if (ctx->x_requires_grad) {
in_grads->at(0) = JUST(functional::BroadcastPowXGrad(out_grads.at(0), x, y, z));
(*in_grads)[0] = JUST(functional::BroadcastPowXGrad(x, y, out_grads[0]));
}
if (ctx->y_requires_grad) {
in_grads->at(1) = JUST(functional::BroadcastPowYGrad(out_grads.at(0), x, y, z));
(*in_grads)[1] = JUST(functional::BroadcastPowYGrad(x, y, out_grads[0]));
}
return Maybe<void>::Ok();
}
......@@ -246,9 +245,8 @@ class BroadcastPow : public BroadcastBinaryGrad {
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);
ctx->y_index = ctx->SaveTensorForBackward(inputs[1]);
return Maybe<void>::Ok();
}
};
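// Reference formulas (a sketch): for z = pow(x, y),
//   dz/dx = y * pow(x, y - 1)   -> BroadcastPowXGrad
//   dz/dy = pow(x, y) * log(x)  -> BroadcastPowYGrad
// Both depend only on x, y and the incoming grad.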
......@@ -348,5 +346,80 @@ class BroadcastMaximum : public BroadcastMinMax {
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);
class BroadcastFMod : public BroadcastBinaryGrad {
public:
Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& out_shape = *(JUST(VectorAt(out_grads, 0))->shape());
in_grads->resize(2);
if (ctx->x_requires_grad || ctx->y_requires_grad) {
const auto& x = JUST(VectorAt(ctx->SavedTensors(), ctx->x_index));
const auto& y = JUST(VectorAt(ctx->SavedTensors(), ctx->y_index));
auto broad_x_ = x;
auto broad_y_ = y;
if (ctx->broadcast_x) {
const auto& x_shape = *(x->shape());
const Shape& left_extended_x_shape =
CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
if (left_extended_x_shape == out_shape) {
broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> x_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
}
}
if (ctx->broadcast_y) {
const auto& y_shape = *(y->shape());
const Shape& left_extended_y_shape =
CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
if (left_extended_y_shape == out_shape) {
broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
} else {
const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
const std::vector<int32_t> y_axis =
std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
}
}
if (ctx->x_requires_grad) {
if (ctx->broadcast_x) {
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::BroadcastReduceSumLike(JUST(VectorAt(out_grads, 0)), x));
} else {
JUST(VectorAt(*in_grads, 0)) = JUST(VectorAt(out_grads, 0));
}
}
if (ctx->y_requires_grad) {
auto result = JUST(functional::TruncDiv(broad_x_, broad_y_));
result = JUST(functional::Mul(JUST(VectorAt(out_grads, 0)), result));
JUST(functional::ScalarMul(result, Scalar(-1.f), true));
if (ctx->broadcast_y) {
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(result, y));
} else {
in_grads->at(1) = result;
}
}
}
return Maybe<void>::Ok();
}
protected:
Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs) const override {
if (ctx->x_requires_grad && ctx->broadcast_x) {
ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
}
if (ctx->y_requires_grad) {
ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
ctx->y_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));
}
return Maybe<void>::Ok();
}
};
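// Reference formulas (a sketch): fmod(x, y) = x - trunc(x / y) * y, and treating trunc(x / y)
// as piecewise constant gives
//   d fmod / dx = 1               -> grad_x is the out_grad, reduced back to x's shape
//   d fmod / dy = -trunc(x / y)   -> grad_y = -out_grad * trunc_div(x, y), reduced back to y's shape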
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {
struct BroadcastFModCaptureState : public AutoGradCaptureState {
bool requires_grad;
};
class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
};
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct CastConsistentCaptureState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
std::shared_ptr<const Shape> shape;
Symbol<DType> dtype;
};
class CastToConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastFromConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
std::shared_ptr<Tensor> out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_to_consistent but got local tensor";
{
Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
out_grad = JUST(functional::ToConsistent(out_grad, parallel_desc_constraint,
*JUST(GetSbpList(nd_sbp_constraint)),
GetNoneSbpList(), /* check_meta */ false));
}
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_to_consistent", CastToConsistent);
class CastFromConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
const std::string& op_name = fw_op_expr->op_name();
grad_op_ = JUST(one::CastToConsistentOpExpr::New(GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
const auto& input = inputs.at(0);
CHECK_OR_RETURN(input->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for cast_from_consistent but got local tensor";
ctx->parallel_desc = JUST(input->parallel_desc());
ctx->nd_sbp = JUST(input->nd_sbp());
ctx->shape = input->shape();
ctx->dtype = input->dtype();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", *ctx->shape));
JUST(attrs.SetAttr<DataType>("dtype", ctx->dtype->data_type()));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(
*grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));
return Maybe<void>::Ok();
}
private:
std::shared_ptr<OpExpr> grad_op_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_from_consistent", CastFromConsistent);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
namespace one {
struct ConsistentToConsistentState : public AutoGradCaptureState {
Symbol<ParallelDesc> parallel_desc;
Symbol<NdSbp> nd_sbp;
};
class ConsistentToConsistentGradFunction : public OpExprGradFunction<ConsistentToConsistentState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const ConsistentToConsistentOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();
return Maybe<void>::Ok();
}
Maybe<void> Capture(ConsistentToConsistentState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const OpExprInterpContext& interp_ctx) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());
ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());
return Maybe<void>::Ok();
}
Maybe<void> Apply(const ConsistentToConsistentState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const auto& out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent())
<< Error::RuntimeError()
<< "Expected global tensor for consistent_to_consistent but got local tensor";
in_grads->resize(1);
const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
(*in_grads)[0] = JUST(one::functional::ToConsistent(
out_grad, ctx->parallel_desc, *grad_sbp_list, *grad_grad_sbp_list, /* check_meta */ false));
return Maybe<void>::Ok();
}
private:
Optional<Symbol<NdSbp>> grad_nd_sbp_;
};
REGISTER_OP_EXPR_GRAD_FUNCTION("consistent_to_consistent", ConsistentToConsistentGradFunction);
} // namespace one
} // namespace oneflow
......@@ -26,6 +26,8 @@ namespace one {
struct ConvolutionNdCaptureState : public AutoGradCaptureState {
bool input_requires_grad = false;
bool weight_requires_grad = false;
bool has_bias = false;
bool bias_requires_grad = false;
size_t input_index;
size_t weight_index;
......@@ -58,10 +60,17 @@ Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_OR_RETURN(inputs.size() == 2 || inputs.size() == 3); // NOLINT(maybe-need-error-msg)
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->weight_requires_grad = inputs.at(1)->requires_grad();
if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); }
if (inputs.size() == 3) {
ctx->has_bias = true;
ctx->bias_requires_grad = inputs.at(2)->requires_grad();
}
if (!ctx->input_requires_grad && !ctx->weight_requires_grad && !ctx->bias_requires_grad) {
return Maybe<void>::Ok();
}
if (ctx->input_requires_grad) {
ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1)); // weight
}
......@@ -79,7 +88,11 @@ Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorT
Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
if (ctx->has_bias) {
in_grads->resize(3);
} else {
in_grads->resize(2);
}
size_t num_spatial_dims = ctx->kernel_size.size();
if (ctx->input_requires_grad) {
const auto& weight = ctx->SavedTensors().at(ctx->weight_index);
......@@ -94,6 +107,18 @@ Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const Ten
out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
}
if (ctx->bias_requires_grad) {
std::vector<int32_t> dim;
for (int i = 0; i < out_grads.at(0)->shape()->NumAxes(); ++i) {
if ((ctx->data_format == "channels_first" && i == 1)
|| (ctx->data_format == "channels_last"
&& i == out_grads.at(0)->shape()->NumAxes() - 1)) {
continue;
}
dim.push_back(i);
}
in_grads->at(2) = JUST(functional::ReduceSum(out_grads.at(0), dim, false));
}
return Maybe<void>::Ok();
}
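// Worked example (illustrative only): for a 4-D out_grad in "channels_first" layout (N, C, H, W)
// the reduction dims above become {0, 2, 3}; in "channels_last" layout (N, H, W, C) they become
// {0, 1, 2}. In both cases the channel axis is kept, so the bias grad has shape (C,).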
......
......@@ -38,8 +38,14 @@ class Copy : public OpExprGradFunction<CopyCaptureState> {
Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->device_type = JUST(inputs.at(0)->device())->type();
ctx->device_id = JUST(inputs.at(0)->device())->device_id();
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
if (inputs[0]->is_global()) {
ctx->device_type = JUST(inputs[0]->parallel_desc())->device_tag();
ctx->device_id = 0; // global tensor only has one local device
} else {
ctx->device_type = JUST(inputs[0]->device())->type();
ctx->device_id = JUST(inputs[0]->device())->device_id();
}
return Maybe<void>::Ok();
}
......
......@@ -57,7 +57,7 @@ Maybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->max_target_length = JUST(composed_attrs.GetAttr<int64_t>("max_target_length"));
ctx->blank = JUST(composed_attrs.GetAttr<int32_t>("blank"));
ctx->blank = JUST(composed_attrs.GetAttr<int64_t>("blank"));
ctx->zero_infinity = JUST(composed_attrs.GetAttr<bool>("zero_infinity"));
CHECK_EQ_OR_RETURN(inputs.size(), 4); // NOLINT(maybe-need-error-msg)
......