OpenDAS / Oneflow / Commits / a715222c

Commit a715222c, authored Feb 28, 2023 by yuguo
0.9.1-rocm
parent f262efc9

Changes: 469 files. Showing 20 changed files with 638 additions and 426 deletions.
oneflow/core/auto_parallel/sbp_util.h                                                   +30   -0
oneflow/core/autograd/autograd_engine.cpp                                               +158  -69
oneflow/core/autograd/autograd_engine.h                                                 +30   -7
oneflow/core/autograd/autograd_meta.cpp                                                 +17   -6
oneflow/core/autograd/autograd_meta.h                                                   +5    -9
oneflow/core/autograd/gradient_funcs/activation.cpp                                     +46   -0
oneflow/core/autograd/gradient_funcs/adaptive_avg_pool.cpp                              +0    -0
oneflow/core/autograd/gradient_funcs/adaptive_max_pool.cpp                              +93   -0
oneflow/core/autograd/gradient_funcs/amp_white_identity.cpp                             +62   -0
oneflow/core/autograd/gradient_funcs/as_strided.cpp                                     +9    -9
oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp                           +26   -28
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp               +30   -30
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp   +16   -16
oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp                           +79   -6
oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp                            +0    -49
oneflow/core/autograd/gradient_funcs/consistent_cast.cpp                                +0    -117
oneflow/core/autograd/gradient_funcs/consistent_to_consistent.cpp                       +0    -74
oneflow/core/autograd/gradient_funcs/conv.cpp                                           +28   -3
oneflow/core/autograd/gradient_funcs/copy.cpp                                           +8    -2
oneflow/core/autograd/gradient_funcs/ctc_loss.cpp                                       +1    -1
Too many changes to show: to preserve performance only 469 of 469+ files are displayed.
oneflow/core/auto_parallel/sbp_util.h  (new file, 0 → 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_

#include "oneflow/core/graph/op_graph.h"

namespace oneflow {
namespace auto_parallel {

// Judge whether we need the same SBP for both producer and consumer
bool RequireSameSbp(const OpNode* consumer, const std::string& ibn);

}  // namespace auto_parallel
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTO_PARALLEL_SBP_UTIL_H_
oneflow/core/autograd/autograd_engine.cpp

@@ -17,17 +17,25 @@ limitations under the License.
#include <memory>
#include <stack>
#include <queue>
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_methods.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/common/env_var/debug_mode.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"

namespace oneflow {
namespace one {
@@ -75,20 +83,16 @@ bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_dat

Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
  autograd::AutoGradMode mode(autograd_mode);
- auto current_grad = JUST(autograd_meta->current_grad()->GetAccTensor({}));
+ auto current_grad = JUST(autograd_meta->current_grad_value());
  if (!current_grad) { return Maybe<void>::Ok(); }
  if (autograd_meta->acc_grad()) {
    // Should not inplace accumulate grad. For example,
    // >>> z = x + y
    // >>> p = x / z
    // >>> p.sum().backward()
    //
    // As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value
    // for dy if dx is shared with dz.
-   const auto& output = JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1,
-                                             /*inplace=*/autograd_meta->is_grad_acc_inplace()));
-   JUST(autograd_meta->set_acc_grad(output));
+   JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1.0, /*inplace=*/true));
  } else {
    // NOTE: acc_grad can not share data with current_grad, because accumulate acc_grad
    // with inplace operation and it maybe change current_grad to get wrong result.
    // See more details in https://github.com/Oneflow-Inc/oneflow/issues/8248
    if (!LazyMode::is_enabled()) { current_grad = JUST(functional::Identity(current_grad)); }
    JUST(autograd_meta->set_acc_grad(current_grad));
  }
  for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {
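A note on the aliasing hazard the NOTE above describes: if the accumulated gradient buffer shares storage with a gradient that another branch of the backward pass still needs, an in-place add corrupts that other gradient. The following is a minimal, self-contained sketch of the effect in plain C++ (the Grad type, add_ helper, and values are illustrative only, not OneFlow API):

#include <iostream>
#include <memory>
#include <vector>

// Toy "tensor": a shared buffer of floats.
using Grad = std::shared_ptr<std::vector<float>>;

// In-place accumulation: acc += other, writing into acc's buffer.
void add_(const Grad& acc, const Grad& other) {
  for (size_t i = 0; i < acc->size(); ++i) { (*acc)[i] += (*other)[i]; }
}

int main() {
  // dz is the gradient flowing to z; dx starts out *sharing* dz's buffer
  // (the situation CopyOrAccGrad guards against with functional::Identity).
  Grad dz = std::make_shared<std::vector<float>>(3, 1.0f);
  Grad dx = dz;  // aliasing, no copy
  Grad dp_over_z = std::make_shared<std::vector<float>>(3, 0.5f);

  add_(dx, dp_over_z);  // dx = dz + dp/z, but this also rewrites dz!

  std::cout << (*dz)[0] << "\n";  // prints 1.5, although dy = dz should stay 1.0
}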
@@ -99,47 +103,50 @@ Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
  return Maybe<void>::Ok();
}

-Maybe<void> RawTorchConsistentTensor(const std::shared_ptr<one::Tensor>& tensor) {
+Maybe<void> RawTouchGlobalTensor(const std::shared_ptr<one::Tensor>& tensor) {
  // Do nothing.
  return Maybe<void>::Ok();
}

-static constexpr auto* TorchConsistentTensor =
-    DECORATE(&RawTorchConsistentTensor, CheckConsistentTensorMeta);
+static constexpr auto* TouchGlobalTensor = DECORATE(&RawTouchGlobalTensor, CheckGlobalTensorMeta);

-Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) {
+Maybe<void> CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) {
  for (const auto& tensor : tensor_tuple) {
-   if (tensor->is_consistent()) { JUST(TorchConsistentTensor(tensor)); }
+   if (tensor->is_global() && tensor->is_eager()) { JUST(TouchGlobalTensor(tensor)); }
  }
  return Maybe<void>::Ok();
}

+std::string GetDebugGraphFileName(const std::string& mode, const std::string& suffix) {
+  return fmt::format("autograd_{}_rank{}_suffix_graph.dot", mode, GlobalProcessCtx::Rank(), suffix);
+}

}  // namespace

Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
                                                                 const TensorTuple& out_grads,
                                                                 bool retain_graph,
                                                                 bool create_graph) {
- JUST(CheckConsistentTensorsMeta(outputs));
- JUST(CheckConsistentTensorsMeta(out_grads));
- DisableCheckConsistentTensorMetaScope disable_meta_check;
+ JUST(CheckGlobalTensorsMeta(outputs));
+ JUST(CheckGlobalTensorsMeta(out_grads));
+ DisableCheckGlobalTensorMetaScope disable_meta_check;
  return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
}

Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
    bool retain_graph, bool create_graph) {
- JUST(CheckConsistentTensorsMeta(outputs));
- JUST(CheckConsistentTensorsMeta(inputs));
- JUST(CheckConsistentTensorsMeta(out_grads));
- DisableCheckConsistentTensorMetaScope disable_meta_check;
+ JUST(CheckGlobalTensorsMeta(outputs));
+ JUST(CheckGlobalTensorsMeta(inputs));
+ JUST(CheckGlobalTensorsMeta(out_grads));
+ DisableCheckGlobalTensorMetaScope disable_meta_check;
  return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
                                              create_graph);
}

-Maybe<void> FunctionNode::AccGrad4RetainGradTensor() {
+Maybe<void> FunctionNode::AccGrad4RetainGradTensor(bool create_graph) {
  for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {
-   if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); }
+   if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), create_graph)); }
  }
  return Maybe<void>::Ok();
}
@@ -149,17 +156,18 @@ Maybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {
    auto& out = output_meta_data_[i];
    if (out->is_leaf() && out->requires_grad()) {
-     JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false));
+     JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/create_graph));
      // control acc_grad to do boxing conditionally
      const auto& acc_grad = out->acc_grad();
-     if (GlobalGradSyncMode::is_enabled() && acc_grad->is_consistent()) {
+     if (!LazyMode::is_enabled() && GlobalGradSyncMode::is_enabled() && acc_grad->is_global()
+         && acc_grad->is_eager()) {
        auto& tensor_info = output_tensor_infos_[i];
        const auto& placement = JUST(tensor_info.placement());
        const auto& nd_sbp = JUST(tensor_info.sbp());
-       JUST(out->set_acc_grad(JUST(functional::ToConsistent(
-           acc_grad, placement, *JUST(GetSbpList(nd_sbp)), GetNoneSbpList(),
-           /* check_meta */ false))));
+       JUST(out->set_acc_grad(
+           JUST(functional::ToGlobal(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
+                                     GetNoneSbpList(), /* check_meta */ false, /*copy=*/false))));
      }
    }
  }
@@ -182,22 +190,30 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
  TensorTuple output_grads(output_meta_data_.size());
  for (int i = 0; i < output_meta_data_.size(); ++i) {
    if (output_meta_data_.at(i)->current_grad()->Empty()) {
-     output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros());
+     // Only initialize out_grads for those requires_grad outputs
+     if (output_meta_data_[i]->requires_grad()) {
+       output_grads[i] = JUST(output_tensor_infos_[i].zeros());
+     }
    } else {
-     const auto& hooks = JUST(oneflow::VectorAt(output_meta_data_, i))->hooks();
      JUST(oneflow::VectorAt(output_grads, i)) =
-         JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad()->GetAccTensor(hooks));
+         JUST(JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad_value());
    }
  }
  JUST(backward_fn_->body(output_grads, &input_grads, create_graph));
  for (int i = 0; i < input_meta_data_.size(); ++i) {
    if (JUST(VectorAt(input_grads, i))) {
-     CHECK_NOTNULL_OR_RETURN(input_meta_data_.at(i))
+     CHECK_NOTNULL_OR_RETURN(input_meta_data_[i])
          << name_
          << " calculate grad for tensor which requires_grad is False. Please submit an issue in "
             "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
             "possible";
-     JUST(input_meta_data_.at(i)->current_grad()->PushPartialTensor(input_grads.at(i)));
+     JUST(input_meta_data_[i]->current_grad()->PushPartialTensor(JUST(VectorAt(input_grads, i))));
    } else {
      CHECK_OR_RETURN(!input_meta_data_[i])
          << name() << "'s input[" << i
          << "] need calculate grad but got nullptr. Please submit an issue in "
             "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
             "possible;";
    }
  }
  return true;
@@ -247,15 +263,64 @@ GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_
  for (const auto& out_tensor : outputs) {
    FunctionNode* node = out_tensor->mut_grad_fn_node().get();
    roots_.emplace_back(node);
-   dependencies_.insert(std::make_pair(node, 0));
  }
}

+Maybe<void> GraphTask::WriteGraphToDotFile(const std::string& file_name) const {
+  auto ExecInfoToDotString = [](const ExecInfo& exec_info) -> std::string {
+    std::stringstream ss;
+    ss << "ExecInfo{\\l";
+    ss << "\tdependencies: " << exec_info.dependencies << "\\l";
+    ss << "\tneed_execute: " << exec_info.need_execute << "\\l";
+    if (exec_info.capture_indices) {
+      ss << "\tcapture_indices: [";
+      for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
+        ss << out_idx_and_capture_idx.second << ", ";
+      }
+      ss << "]\\l";
+    }
+    ss << "}\\l";
+    return ss.str();
+  };
+
+  auto log_stream = TeePersistentLogStream::Create(file_name);
+  std::vector<std::string> lines;
+  lines.emplace_back("digraph AutogradTaskGraph {");
+  lines.emplace_back("\tmargin=\"1.5\";");
+  lines.emplace_back("\tnode [shape=box];");
+  for (auto iter = grad_fn2exec_info_.begin(); iter != grad_fn2exec_info_.end(); ++iter) {
+    const FunctionNode* node = iter->first;
+    const ExecInfo& exec_info = iter->second;
+    // write label attribute
+    std::string node_color = "black";
+    if (exec_info.dependencies == 0 && exec_info.need_execute) {
+      // start node
+      node_color = "red";
+    } else if (exec_info.need_execute && exec_info.capture_indices) {
+      // end node
+      node_color = "green";
+    }
+    lines.emplace_back(fmt::format("\t\"{}\" [label=\"{}\\l{}\\l{}\", color={}];",
+                                   static_cast<const void*>(node), node->name(),
+                                   static_cast<const void*>(node), ExecInfoToDotString(exec_info),
+                                   node_color));
+    // write edge
+    for (const auto& next_fn : node->next_functions()) {
+      lines.emplace_back(fmt::format("\t\"{}\" -> \"{}\";", static_cast<const void*>(node),
+                                     static_cast<const void*>(next_fn.get())));
+    }
+  }
+  lines.emplace_back("}");
+  log_stream << fmt::format("{}", fmt::join(lines, "\n"));
+  log_stream->Flush();
+  return Maybe<void>::Ok();
+}
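For orientation, a file produced by the format strings above would look roughly like the following for a two-node task graph; the node addresses and names here are made up, and the label bodies come from ExecInfoToDotString:

digraph AutogradTaskGraph {
	margin="1.5";
	node [shape=box];
	"0x55d0a1b2c3d0" [label="matmul_backward\l0x55d0a1b2c3d0\lExecInfo{\l\tdependencies: 0\l\tneed_execute: 1\l}\l", color=red];
	"0x55d0a1b2c3d0" -> "0x55d0a1b2f7a0";
	"0x55d0a1b2f7a0" [label="accumulate_grad\l0x55d0a1b2f7a0\lExecInfo{\l\tdependencies: 1\l\tneed_execute: 1\l}\l", color=black];
}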
// Computes the number of dependencies for each FunctionNode
Maybe<void> GraphTask::ComputeDependencies() {
  HashSet<FunctionNode*> seen;
  std::stack<FunctionNode*> stack;
- for (FunctionNode* node : roots_) { stack.push(node); }
+ for (FunctionNode* node : roots_) {
+   stack.push(node);
+   grad_fn2exec_info_[node].need_execute = true;
+ }

  while (!stack.empty()) {
    FunctionNode* node = stack.top();
@@ -263,7 +328,9 @@ Maybe<void> GraphTask::ComputeDependencies() {
    if (/*bool has_seen=*/!seen.insert(node).second) { continue; }
    for (const auto& next_grad_fn : node->next_functions()) {
      FunctionNode* next_node = next_grad_fn.get();
-     dependencies_[next_node] += 1;
+     ExecInfo& exec_info = grad_fn2exec_info_[next_node];
+     exec_info.dependencies += 1;
+     exec_info.need_execute = true;
      if (seen.find(next_node) == seen.end()) { stack.push(next_node); }
    }
  }
@@ -288,9 +355,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
    }
  };

- for (const auto& input : inputs) {
-   CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());
-   need_execute_.insert(input->mut_grad_fn_node().get());
+ // initialize all variable to capture grad for input tensors
+ captured_grads_ = std::make_shared<TensorTuple>(inputs.size());
+ for (int idx = 0; idx < inputs.size(); idx++) {
+   const auto& input = inputs[idx];
+   CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());  // NOLINT(maybe-need-error-msg)
+   ExecInfo& exec_info = grad_fn2exec_info_[input->mut_grad_fn_node().get()];
+   exec_info.need_execute = true;
+   if (!exec_info.capture_indices) {
+     exec_info.capture_indices = std::make_unique<std::vector<std::pair<size_t, size_t>>>();
+   }
+   exec_info.capture_indices->emplace_back(std::make_pair(input->get_grad_fn_output_index(), idx));
  }
  HashSet<FunctionNode*> seen;
@@ -305,18 +380,17 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs
      continue;
    }
    if (FunctionNode* node = frame.GetNextFunction()) {
-     dependencies_[node] += 1;
+     grad_fn2exec_info_[node].dependencies += 1;
      if (seen.find(node) == seen.end()) {
        stack.push(NodeFrame(node));
        continue;  // recurse
      }
    } else {
-     bool need_execute =
+     grad_fn2exec_info_[frame.node_].need_execute |=
          std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(),
                      [&](const std::shared_ptr<FunctionNode>& fn) {
-                       return need_execute_.find(fn.get()) != need_execute_.end();
+                       return grad_fn2exec_info_[fn.get()].need_execute;
                      });
-     if (need_execute) { need_execute_.insert(frame.node_); }
      seen.insert(frame.node_);
      stack.pop();
    }
@@ -327,26 +401,38 @@ Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs

Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
  std::queue<FunctionNode*> queue;
  for (FunctionNode* node : roots_) {
-   if (dependencies_[node] == 0) { queue.push(node); }
+   if (grad_fn2exec_info_[node].dependencies == 0) { queue.push(node); }
  }

  while (!queue.empty()) {
    FunctionNode* node = queue.front();
    queue.pop();
-   if (!need_execute_.empty() && need_execute_.find(node) == need_execute_.end()) {
+   auto& exec_info = grad_fn2exec_info_[node];
+   if (!exec_info.need_execute) {
      node->ReleaseOutTensorArgs();
      continue;
    }
+   BackwardPassScopeGuard backward_guard(node->scope());
    if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }
+   if (exec_info.capture_indices) {
+     CHECK_NOTNULL_OR_RETURN(captured_grads_.get()) << "captured grads in GraphTask is nullptr";
+     for (const auto& out_idx_and_capture_idx : *exec_info.capture_indices) {
+       JUST(VectorAt(*captured_grads_, out_idx_and_capture_idx.second)) =
+           JUST(JUST(VectorAt(node->output_meta_data_, out_idx_and_capture_idx.first))
+                    ->current_grad_value());
+     }
+   }
    if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }
-   JUST(node->AccGrad4RetainGradTensor());
+   JUST(node->AccGrad4RetainGradTensor(create_graph_));
    node->ReleaseOutTensorArgs();
    if (!retain_graph_) { node->ReleaseData(); }
    for (const auto& next_grad_fn : node->next_functions()) {
      FunctionNode* next_node = next_grad_fn.get();
-     dependencies_[next_node] -= 1;
-     if (dependencies_[next_node] == 0) { queue.push(next_node); }
+     int32_t& dependencies = grad_fn2exec_info_[next_node].dependencies;
+     dependencies -= 1;
+     if (dependencies == 0) { queue.push(next_node); }
    }
  }
  return Maybe<void>::Ok();
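The scheduling above is essentially Kahn-style topological execution over the backward graph: count how many producers feed each node, then repeatedly run nodes whose count has dropped to zero. A stripped-down, self-contained sketch of that pattern in plain C++ (independent of the FunctionNode/ExecInfo types used here):

#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> next;  // nodes this node feeds gradients into
};

// Run every node only after all of its producers have run.
void RunBackward(const std::vector<Node*>& roots) {
  std::unordered_map<Node*, int> dependencies;
  std::unordered_map<Node*, bool> seen;
  // 1) count dependencies by walking the graph from the roots
  std::vector<Node*> stack(roots.begin(), roots.end());
  while (!stack.empty()) {
    Node* node = stack.back();
    stack.pop_back();
    if (seen[node]) { continue; }
    seen[node] = true;
    for (Node* next : node->next) {
      dependencies[next] += 1;
      stack.push_back(next);
    }
  }
  // 2) execute nodes whose dependency count reaches zero
  std::queue<Node*> queue;
  for (Node* root : roots) {
    if (dependencies[root] == 0) { queue.push(root); }
  }
  while (!queue.empty()) {
    Node* node = queue.front();
    queue.pop();
    std::cout << "apply " << node->name << "\n";  // stands in for FunctionNode::Apply
    for (Node* next : node->next) {
      if (--dependencies[next] == 0) { queue.push(next); }
    }
  }
}

int main() {
  Node acc{"accumulate_grad", {}};
  Node add{"add_backward", {&acc}};
  Node mul{"mul_backward", {&acc}};
  Node root{"loss_backward", {&add, &mul}};
  // Prints: loss_backward, add_backward, mul_backward, accumulate_grad
  RunBackward({&root});
}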
@@ -361,6 +447,10 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
  }
  GraphTask graph_task(outputs, retain_graph, create_graph);
  JUST(graph_task.ComputeDependencies());
+ if (IsInDebugMode()) {
+   JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("backward", std::to_string(clock()))));
+ }
  JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));
  return Maybe<void>::Ok();
}
@@ -368,34 +458,23 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor

Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
    bool retain_graph, bool create_graph) {
- std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
- GraphTask graph_task(outputs, retain_graph, create_graph);
- std::vector<bool> ori_retain_grad(inputs.size());
- for (int i = 0; i < inputs.size(); ++i) {
-   ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
-   JUST(inputs.at(i)->set_retain_grad(true));
- }
  for (int i = 0; i < outputs.size(); ++i) {
    JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
  }
+ GraphTask graph_task(outputs, retain_graph, create_graph);
  JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
- JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
-
- // Gets input grads and resume retain_grad
- for (int i = 0; i < inputs.size(); ++i) {
-   input_current_grad->at(i) = JUST(inputs.at(i)->acc_grad());
-   if (!ori_retain_grad.at(i)) {
-     JUST(inputs.at(i)->set_acc_grad(nullptr));
-     JUST(inputs.at(i)->set_retain_grad(false));
-   }
- }
- return input_current_grad;
+ if (IsInDebugMode()) {
+   JUST(graph_task.WriteGraphToDotFile(GetDebugGraphFileName("grad", std::to_string(clock()))));
+ }
+ JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));
+ return graph_task.GetCapturedGrads();
}

Maybe<FunctionNode> GraphAutogradEngine::AddNode(const std::string& name,
                                                 const std::shared_ptr<BackwardFunction>& backward_fn,
                                                 const TensorTuple& inputs, TensorTuple* outputs) {
  OF_PROFILER_RANGE_PUSH("AddAccumulateFunctionNode");
  // Firstly push function_node of tensor in stack which is leaf and requires_grad
  for (const std::shared_ptr<Tensor>& in_tensor : inputs) {
    if (in_tensor->is_leaf() && in_tensor->requires_grad()) {

@@ -403,11 +482,17 @@ Maybe<FunctionNode> GraphAutogradEngine::AddNode(
    }
  }
  OF_PROFILER_RANGE_POP();

  OF_PROFILER_RANGE_PUSH("set_grad_fn_node");
  std::shared_ptr<FunctionNode> func_node =
      GraphFunctionNode::New(name, backward_fn, inputs, *outputs);
- for (const std::shared_ptr<Tensor>& out_tensor : *outputs) {
+ for (int i = 0; i < outputs->size(); ++i) {
+   const std::shared_ptr<Tensor>& out_tensor = JUST(VectorAt(*outputs, i));
    out_tensor->set_grad_fn_node(func_node);
+   out_tensor->set_grad_fn_output_index(i);
  }
+ if (LazyMode::is_enabled()) { func_node->set_scope(JUST(GetCurrentScope())); }
  OF_PROFILER_RANGE_POP();
  return func_node;
}
@@ -423,6 +508,10 @@ Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {
  backward_fn->status = []() { return false; };
  tensor->set_grad_fn_node(GraphFunctionNode::New(
      "accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor}));
+ tensor->set_grad_fn_output_index(0);
+ if (LazyMode::is_enabled()) {
+   tensor->mut_grad_fn_node()->set_scope(JUST(GetTensorScope(tensor)));
+ }
  return Maybe<void>::Ok();
}
oneflow/core/autograd/autograd_engine.h

@@ -17,12 +17,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_

#include <functional>
#include <list>
#include <vector>
#include <memory>
#include <functional>
#include "oneflow/core/common/util.h"
#include <vector>
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/job/lazy_mode.h"

namespace oneflow {

@@ -45,7 +48,7 @@ class FunctionNode {
  Maybe<bool> Apply(bool create_graph);
  Maybe<void> AccGrad4LeafTensor(bool create_graph);
- Maybe<void> AccGrad4RetainGradTensor();
+ Maybe<void> AccGrad4RetainGradTensor(bool create_graph);
  void ReleaseOutTensorArgs();
  // Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling
  // `Apply` in second time

@@ -56,10 +59,14 @@ class FunctionNode {
  }
  const std::string& name() const { return name_; }
+ const std::shared_ptr<Scope>& scope() const { return scope_; }
+ void set_scope(const std::shared_ptr<Scope>& scope) { scope_ = scope; }

 protected:
+ friend class GraphTask;
  explicit FunctionNode(const std::string& name,
                        const std::shared_ptr<BackwardFunction>& backward_fn)
-     : name_(name), backward_fn_(backward_fn) {}
+     : name_(name), backward_fn_(backward_fn), scope_(nullptr) {}

  const std::string name_;
  std::vector<std::shared_ptr<FunctionNode>> next_functions_;

@@ -70,6 +77,9 @@ class FunctionNode {
  // Actual backward function builds in `AutogradInterpreter` to calculate one backward op
  std::shared_ptr<BackwardFunction> backward_fn_;
+ // The execution scope
+ std::shared_ptr<Scope> scope_;
};

class AutogradEngine {

@@ -130,13 +140,26 @@ class GraphTask final {
  Maybe<void> ComputeDependencies();
  Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
  Maybe<void> Apply(bool save_grad_for_leaf);
+ std::shared_ptr<TensorTuple> GetCapturedGrads() const { return captured_grads_; }
+ Maybe<void> WriteGraphToDotFile(const std::string& file_name) const;

 private:
+ class ExecInfo {
+  public:
+   ExecInfo() = default;
+
+   int32_t dependencies = 0;
+   bool need_execute = false;
+   // Used in autograd.grad interface, to record which grad of tensor will be captured.
+   // The pair means: <output index of this Node, the index of captured_grads_ to be saved>
+   std::unique_ptr<std::vector<std::pair<size_t, size_t>>> capture_indices;
+ };
+
  bool retain_graph_;
  bool create_graph_;
  std::vector<FunctionNode*> roots_;
- HashMap<FunctionNode*, int> dependencies_;
- HashSet<FunctionNode*> need_execute_;
+ HashMap<FunctionNode*, ExecInfo> grad_fn2exec_info_;
+ std::shared_ptr<TensorTuple> captured_grads_;
};

class GraphAutogradEngine final : public AutogradEngine {
oneflow/core/autograd/autograd_meta.cpp

@@ -25,9 +25,12 @@ namespace oneflow {
namespace one {

TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
- if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
- if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
- if (TRY(tensor.nd_sbp()).IsOk()) { nd_sbp_ = CHECK_JUST(tensor.nd_sbp()); }
+ if (tensor.is_global()) {
+   parallel_desc_ = CHECK_JUST(tensor.parallel_desc());
+   nd_sbp_ = CHECK_JUST(tensor.nd_sbp());
+ } else {
+   device_ = CHECK_JUST(tensor.device());
+ }
}

Maybe<const std::vector<Symbol<SbpParallel>>&> GetSbpTuple(Symbol<NdSbp> nd_sbp) {

@@ -52,7 +55,7 @@ Maybe<Tensor> TensorInfo::zeros() const {
    const auto& parallel_desc = JUST(parallel_desc_);
    const auto& nd_sbp = JUST(nd_sbp_);
    const auto& sbp_tuple = JUST(GetSbpTuple(nd_sbp));
-   return functional::ConsistentConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
+   return functional::GlobalConstant(*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
  }
}

@@ -60,18 +63,26 @@ AutogradMeta::AutogradMeta(bool requires_grad, bool is_leaf)
    : is_leaf_(is_leaf),
      requires_grad_(requires_grad),
      retain_grad_(false),
-     is_grad_acc_inplace_(false),
      current_grad_(new TensorArg) {}

Maybe<void> AutogradMeta::set_acc_grad(const std::shared_ptr<Tensor>& grad) {
  if (const auto& static_zeros_tensor = std::dynamic_pointer_cast<StaticZerosTensor>(grad)) {
-   acc_grad_ = JUST(static_zeros_tensor->AsMirroredTensor());
+   acc_grad_ = JUST(static_zeros_tensor->AsLocalTensor());
  } else {
    acc_grad_ = grad;
  }
  return Maybe<void>::Ok();
}

+Maybe<Tensor> AutogradMeta::current_grad_value() const {
+  std::shared_ptr<Tensor> res = JUST(current_grad_->GetAccTensor());
+  for (const auto& hook : hooks_) {
+    const auto& new_tensor = hook(res);
+    if (new_tensor) { res = new_tensor; }
+  }
+  return res;
+}

}  // namespace one
}  // namespace oneflow
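current_grad_value() above folds the registered hooks over the pending gradient, keeping a hook's result only when it returns a non-null tensor. A minimal stand-alone illustration of that folding pattern in plain C++ (the Tensor alias and hooks here are placeholders, not OneFlow's types):

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

using Tensor = std::shared_ptr<float>;  // toy stand-in for one::Tensor
using Hook = std::function<Tensor(const Tensor&)>;

// Apply hooks in registration order; a hook may return nullptr to mean
// "leave the gradient unchanged", mirroring the loop in current_grad_value().
Tensor ApplyHooks(Tensor grad, const std::vector<Hook>& hooks) {
  for (const auto& hook : hooks) {
    const auto& new_tensor = hook(grad);
    if (new_tensor) { grad = new_tensor; }
  }
  return grad;
}

int main() {
  std::vector<Hook> hooks;
  hooks.emplace_back([](const Tensor& g) { return std::make_shared<float>(*g * 2.0f); });
  hooks.emplace_back([](const Tensor& g) -> Tensor { return nullptr; });  // observer only

  Tensor grad = std::make_shared<float>(3.0f);
  std::cout << *ApplyHooks(grad, hooks) << "\n";  // prints 6
}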
oneflow/core/autograd/autograd_meta.h

@@ -36,7 +36,7 @@ namespace one {

class Tensor;
class TensorArg;
-class MirroredTensor;
+class LocalTensor;

class AutogradMeta final {
 public:

@@ -46,7 +46,8 @@ class AutogradMeta final {
  // Getters
  const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
  const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }
- bool is_grad_acc_inplace() const { return is_grad_acc_inplace_; }
+ // get current grad processed by hooks
+ Maybe<Tensor> current_grad_value() const;
  bool requires_grad() const { return requires_grad_; }
  bool is_leaf() const { return is_leaf_; }
  bool retain_grad() const { return retain_grad_; }

@@ -59,7 +60,6 @@ class AutogradMeta final {
  // Setters
  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
  std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
- void set_is_grad_acc_inplace(bool is_inplace) { is_grad_acc_inplace_ = is_inplace; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
  void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
  void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }

@@ -77,10 +77,6 @@ class AutogradMeta final {
  // Only meaningful on non_leaf Tensors (must be false otherwise)
  bool retain_grad_;

- // Control whether grad accumulation is inplace. Don't change it
- // unless you know what you are doing
- bool is_grad_acc_inplace_;
-
  std::shared_ptr<Tensor> acc_grad_;
  std::shared_ptr<TensorArg> current_grad_;
  std::vector<Hook> hooks_;

@@ -104,8 +100,8 @@ class TensorInfo final {
  std::shared_ptr<const Shape> shape_;
  Symbol<DType> dtype_;
  Optional<Symbol<Device>> device_;               // for local tensor
- Optional<Symbol<ParallelDesc>> parallel_desc_;  // for consistent tensor
- Optional<Symbol<NdSbp>> nd_sbp_;                // for consistent tensor
+ Optional<Symbol<ParallelDesc>> parallel_desc_;  // for global tensor
+ Optional<Symbol<NdSbp>> nd_sbp_;                // for global tensor
};

}  // namespace one
oneflow/core/autograd/gradient_funcs/activation.cpp

@@ -108,6 +108,50 @@ class GeLU : public BaseActivation {
  }
};

+class FastGeLU : public BaseActivation {
+ public:
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::FastGeluGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+struct QuickGeluCaptureState : public AutoGradCaptureState {
+  bool requires_grad = false;
+};
+
+class QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(QuickGeluCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const QuickGeluCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::QuickGeluGrad(out_grads.at(0), x));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
class HardSigmoid : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,

@@ -558,6 +602,8 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
+REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", FastGeLU);
+REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU);

}  // namespace one
}  // namespace oneflow
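As a sanity check for the new quick_gelu registration, here is a scalar reference for the backward, assuming the common definition quick_gelu(x) = x * sigmoid(1.702 * x); the constant and the exact functional::QuickGeluGrad signature are not confirmed by this diff, so treat this as an illustrative sketch only:

#include <cmath>
#include <cstdio>

// Assumed definition: quick_gelu(x) = x * sigmoid(1.702 * x).
double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
double quick_gelu(double x) { return x * sigmoid(1.702 * x); }

// d/dx [x * s(a*x)] = s(a*x) + a * x * s(a*x) * (1 - s(a*x)), with a = 1.702.
double quick_gelu_grad(double dy, double x) {
  const double a = 1.702;
  const double s = sigmoid(a * x);
  return dy * (s + a * x * s * (1.0 - s));
}

int main() {
  // Compare against a central finite difference at x = 0.7.
  const double x = 0.7, eps = 1e-6;
  const double numeric = (quick_gelu(x + eps) - quick_gelu(x - eps)) / (2 * eps);
  std::printf("analytic=%.6f numeric=%.6f\n", quick_gelu_grad(1.0, x), numeric);
}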
oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp → oneflow/core/autograd/gradient_funcs/adaptive_avg_pool.cpp
File moved
oneflow/core/autograd/gradient_funcs/adaptive_max_pool.cpp  (new file, 0 → 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct AdaptiveMaxPoolCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class AdaptiveMaxPoolNdGrad : public OpExprGradFunction<AdaptiveMaxPoolCaptureState> {
 public:
  using OpExprGradFunction<AdaptiveMaxPoolCaptureState>::Init;

  Maybe<void> Init(const OpExpr& op, const int& ndims);
  Maybe<void> Capture(AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const AdaptiveMaxPoolCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  int32_t ndims_ = 0;
};

Maybe<void> AdaptiveMaxPoolNdGrad::Init(const OpExpr& op, const int& ndims) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  ndims_ = ndims;
  return Maybe<void>::Ok();
}

Maybe<void> AdaptiveMaxPoolNdGrad::Capture(AdaptiveMaxPoolCaptureState* ctx,
                                           const TensorTuple& inputs, const TensorTuple& outputs,
                                           const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ctx->SaveTensorForBackward(inputs.at(0));
  ctx->SaveTensorForBackward(outputs.at(1));
  return Maybe<void>::Ok();
}

Maybe<void> AdaptiveMaxPoolNdGrad::Apply(const AdaptiveMaxPoolCaptureState* ctx,
                                         const TensorTuple& out_grads,
                                         TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)

  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
  const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(1);
  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::AdaptiveMaxPoolNdGrad(x, out_grads.at(0), index, ndims_));
  return Maybe<void>::Ok();
}

class AdaptiveMaxPool1dGrad final : public AdaptiveMaxPoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 1); }
};

class AdaptiveMaxPool2dGrad final : public AdaptiveMaxPoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 2); }
};

class AdaptiveMaxPool3dGrad final : public AdaptiveMaxPoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptiveMaxPoolNdGrad::Init(op, 3); }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d", AdaptiveMaxPool1dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d", AdaptiveMaxPool2dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d", AdaptiveMaxPool3dGrad);

}  // namespace one
}  // namespace oneflow
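The backward above relies on the forward op's second output (outputs.at(1)), the argmax index map: each output gradient is routed back to the single input element that won the max. A minimal 1-D sketch of that routing in plain C++ (not the functional::AdaptiveMaxPoolNdGrad kernel itself):

#include <cstdio>
#include <vector>

// Scatter each pooled-output gradient to the input position recorded by argmax.
std::vector<float> MaxPoolBackward1d(const std::vector<float>& dy,
                                     const std::vector<int>& argmax_index, int input_size) {
  std::vector<float> dx(input_size, 0.0f);
  for (size_t o = 0; o < dy.size(); ++o) { dx[argmax_index[o]] += dy[o]; }
  return dx;
}

int main() {
  // A forward pass (not shown) pooled 6 inputs down to 2 outputs and recorded winners 1 and 5.
  std::vector<float> dy = {0.3f, 0.7f};
  std::vector<int> argmax_index = {1, 5};
  for (float v : MaxPoolBackward1d(dy, argmax_index, 6)) { std::printf("%.1f ", v); }
  std::printf("\n");  // 0.0 0.3 0.0 0.0 0.0 0.7
}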
oneflow/core/autograd/gradient_funcs/amp_white_identity.cpp  (new file, 0 → 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

enum class AmpIdentityType {
  kWhite = 0,
  kBlack,
};

struct AmpIdentityCaptureState : public AutoGradCaptureState {};

template<AmpIdentityType type>
class AmpIdentityGrad : public OpExprGradFunction<AmpIdentityCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(AmpIdentityCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AmpIdentityCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    if (type == AmpIdentityType::kWhite) {
      (*in_grads)[0] = JUST(functional::AmpWhiteIdentity(out_grads[0]));
    } else if (type == AmpIdentityType::kBlack) {
      (*in_grads)[0] = JUST(functional::AmpBlackIdentity(out_grads[0]));
    } else {
      (*in_grads)[0] = out_grads[0];
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("amp_white_identity", AmpIdentityGrad<AmpIdentityType::kWhite>);
REGISTER_OP_EXPR_GRAD_FUNCTION("amp_black_identity", AmpIdentityGrad<AmpIdentityType::kBlack>);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/as_strided.cpp

@@ -22,9 +22,9 @@ namespace oneflow {
namespace one {

struct AsStridedCaptureState : public AutoGradCaptureState {
- std::vector<int32_t> size;
- std::vector<int32_t> stride;
- int32_t storage_offset = 0;
+ std::vector<int64_t> size;
+ std::vector<int64_t> stride;
+ int64_t storage_offset = 0;
  bool requires_grad = false;
};

@@ -55,9 +55,9 @@ Maybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& in
  ctx->SaveTensorForBackward(inputs.at(0));

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
- ctx->size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("size"));
- ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
- ctx->storage_offset = JUST(composed_attrs.GetAttr<int32_t>("storage_offset"));
+ ctx->size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("size"));
+ ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stride"));
+ ctx->storage_offset = JUST(composed_attrs.GetAttr<int64_t>("storage_offset"));
  return Maybe<void>::Ok();
}

@@ -67,9 +67,9 @@ Maybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  const auto& input = ctx->SavedTensors().at(0);
- std::vector<int32_t> size = ctx->size;
- std::vector<int32_t> stride = ctx->stride;
- int32_t storage_offset = ctx->storage_offset;
+ std::vector<int64_t> size = ctx->size;
+ std::vector<int64_t> stride = ctx->stride;
+ int64_t storage_offset = ctx->storage_offset;

  in_grads->at(0) = JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));
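The change above widens size/stride/storage_offset from int32_t to int64_t. That matters because an as_strided view addresses flat storage as offset = storage_offset + sum(idx[d] * stride[d]), which can exceed the 32-bit range for large tensors. A small illustrative sketch of that address computation in plain C++ (not the OneFlow kernel):

#include <cstdint>
#include <cstdio>
#include <vector>

// Flat storage offset of element `idx` in an as_strided view.
int64_t StridedOffset(const std::vector<int64_t>& idx, const std::vector<int64_t>& stride,
                      int64_t storage_offset) {
  int64_t offset = storage_offset;
  for (size_t d = 0; d < idx.size(); ++d) { offset += idx[d] * stride[d]; }
  return offset;
}

int main() {
  // A 100000 x 100000 row-major view: element (99999, 99999) lives past 2^31 - 1,
  // so 32-bit strides/offsets would overflow.
  std::printf("%lld\n", (long long)StridedOffset({99999, 99999}, {100000, 1}, 0));
}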
oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp

@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
+ bool has_weight = false;
};

class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureSt

@@ -30,46 +32,42 @@ class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureSt
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
-
- private:
-  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) {
- const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
- CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
- base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
                                        const TensorTuple& inputs, const TensorTuple& outputs,
                                        const AttrMap& attrs) const {
- ctx->requires_grad = inputs.at(0)->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-
- ComposedAttrMap composed_attrs(attrs, base_attrs_);
- ctx->SaveTensorForBackward(inputs.at(0));  // input
- ctx->SaveTensorForBackward(inputs.at(1));  // target
- if (inputs.size() == 3) {
-   ctx->SaveTensorForBackward(inputs.at(2));  // weight
+ CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 3);  // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = inputs[0]->requires_grad();
+ ctx->target_requires_grad = inputs[1]->requires_grad();
+ ctx->has_weight = inputs.size() == 3;
+ ctx->SaveTensorForBackward(inputs[0]);  // input
+ ctx->SaveTensorForBackward(inputs[1]);  // target
+ if (ctx->has_weight) {
+   ctx->SaveTensorForBackward(inputs[2]);  // weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
- const auto& dy = out_grads.at(0);
- const auto& input = ctx->SavedTensors().at(0);
- const auto& target = ctx->SavedTensors().at(1);
- in_grads->resize(ctx->SavedTensors().size());
-
- if (ctx->SavedTensors().size() == 3) {
-   const auto& weight = ctx->SavedTensors().at(2);
-   in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
- } else {
-   in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt));
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2 + ctx->has_weight);  // NOLINT(maybe-need-error-msg)
+ in_grads->resize(2 + ctx->has_weight);
+ const auto& dy = out_grads[0];
+ const auto& input = ctx->SavedTensors()[0];
+ const auto& target = ctx->SavedTensors()[1];
+ const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
+ if (ctx->input_requires_grad) {
+   (*in_grads)[0] = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
+ }
+ if (ctx->target_requires_grad) {
+   (*in_grads)[1] = JUST(functional::BinaryCrossEntropyLossTargetGrad(dy, input, target, weight));
  }
  return Maybe<void>::Ok();
}
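For reference, the two gradients the rewritten Apply now fills follow directly from L = -w * (t*log(x) + (1-t)*log(1-x)): dL/dx = w*(x-t)/(x*(1-x)) and dL/dt = w*(log(1-x) - log(x)). A scalar sketch in plain C++ (not the functional:: kernels; weight defaults to 1 when absent):

#include <cmath>
#include <cstdio>

// dL/dx for L = -w * (t*log(x) + (1-t)*log(1-x)), scaled by the upstream grad dy.
double bce_input_grad(double dy, double x, double t, double w = 1.0) {
  return dy * w * (x - t) / (x * (1.0 - x));
}

// dL/dt for the same loss.
double bce_target_grad(double dy, double x, double t, double w = 1.0) {
  return dy * w * (std::log(1.0 - x) - std::log(x));
}

int main() {
  std::printf("dx=%.4f dt=%.4f\n", bce_input_grad(1.0, 0.8, 1.0), bce_target_grad(1.0, 0.8, 1.0));
}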
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp

@@ -20,7 +20,9 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
+ bool has_weight = false;
  bool has_pos_weight = false;
};

@@ -47,53 +49,51 @@ Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCa
                                                  const TensorTuple& inputs,
                                                  const TensorTuple& outputs,
                                                  const AttrMap& attrs) const {
- ctx->requires_grad = inputs.at(0)->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-
+ CHECK_OR_RETURN(inputs.size() >= 2 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = inputs[0]->requires_grad();
+ ctx->target_requires_grad = inputs[1]->requires_grad();
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
- ctx->SaveTensorForBackward(inputs.at(0));  // input
- ctx->SaveTensorForBackward(inputs.at(1));  // target
+ ctx->has_weight = inputs.size() == 4 || (inputs.size() == 3 && !ctx->has_pos_weight);
+ ctx->SaveTensorForBackward(inputs[0]);  // input
+ ctx->SaveTensorForBackward(inputs[1]);  // target
  if (inputs.size() == 3) {
-   ctx->SaveTensorForBackward(inputs.at(2));  // weight or pos_weight
+   ctx->SaveTensorForBackward(inputs[2]);  // weight or pos_weight
  }
  if (inputs.size() == 4) {
-   ctx->SaveTensorForBackward(inputs.at(2));  // weight
-   ctx->SaveTensorForBackward(inputs.at(3));  // pos_weight
+   ctx->SaveTensorForBackward(inputs[2]);  // weight
+   ctx->SaveTensorForBackward(inputs[3]);  // pos_weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogits::Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
                                                const TensorTuple& out_grads,
                                                TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
- const auto& dy = out_grads.at(0);
- const auto& input = ctx->SavedTensors().at(0);
- const auto& target = ctx->SavedTensors().at(1);
+ CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
+                    2 + ctx->has_weight + ctx->has_pos_weight);  // NOLINT(maybe-need-error-msg)
+ const auto& dy = out_grads[0];
+ const auto& input = ctx->SavedTensors()[0];
+ const auto& target = ctx->SavedTensors()[1];
  in_grads->resize(ctx->SavedTensors().size());
- if (ctx->SavedTensors().size() == 3) {
-   if (ctx->has_pos_weight) {
-     const auto& pos_weight = ctx->SavedTensors().at(2);
-     in_grads->at(0) = JUST(
-         functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
-   } else {
-     const auto& weight = ctx->SavedTensors().at(2);
-     in_grads->at(0) = JUST(
-         functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
-   }
- } else if (ctx->SavedTensors().size() == 4) {
-   const auto& weight = ctx->SavedTensors().at(2);
-   const auto& pos_weight = ctx->SavedTensors().at(3);
-   in_grads->at(0) = JUST(
-       functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
- } else {
-   in_grads->at(0) = JUST(
-       functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
- }
+ size_t pos_weight_index = ctx->has_weight ? 3 : 2;
+ auto weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[2]) : NullOpt;
+ auto pos_weight =
+     ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;
+ if (ctx->input_requires_grad) {
+   (*in_grads)[0] = JUST(
+       functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
+ }
+ if (ctx->target_requires_grad) {
+   (*in_grads)[1] = JUST(
+       functional::BinaryCrossEntropyWithLogitsLossTargetGrad(dy, input, target, weight, pos_weight));
+ }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp

@@ -21,8 +21,8 @@ namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
- bool requires_grad = false;
- bool has_pos_weight = false;
+ bool input_requires_grad = false;
+ bool target_requires_grad = false;
};

class BinaryCrossEntropyWithLogitsReduceMean

@@ -34,25 +34,19 @@ class BinaryCrossEntropyWithLogitsReduceMean
                      const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;
-
- private:
-  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
- const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
- CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
- base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
    BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
    const TensorTuple& outputs, const AttrMap& attrs) const {
- ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-
+ CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
+ ctx->input_requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
+ ctx->target_requires_grad = JUST(VectorAt(inputs, 1))->requires_grad();
- ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target
  return Maybe<void>::Ok();

@@ -61,14 +55,20 @@ Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
    const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
    TensorTuple* in_grads) const {
- if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
- CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads size should be equal to 1. ";
+ CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& dy = JUST(VectorAt(out_grads, 0));
  const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
  const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
- in_grads->resize(ctx->SavedTensors().size());
- JUST(VectorAt(*in_grads, 0)) =
-     JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
+ in_grads->resize(2);
+ if (ctx->input_requires_grad) {
+   (*in_grads)[0] =
+       JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
+ }
+ if (ctx->target_requires_grad) {
+   (*in_grads)[1] =
+       JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossTargetGrad(dy, input, target));
+ }
  return Maybe<void>::Ok();
}
oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
View file @
a715222c
...
...
@@ -232,13 +232,12 @@ class BroadcastPow : public BroadcastBinaryGrad {
TensorTuple
*
in_grads
)
const
override
{
const
auto
&
x
=
ctx
->
SavedTensors
().
at
(
ctx
->
x_index
);
const
auto
&
y
=
ctx
->
SavedTensors
().
at
(
ctx
->
y_index
);
const
auto
&
z
=
ctx
->
SavedTensors
().
at
(
ctx
->
z_index
);
in_grads
->
resize
(
2
);
if
(
ctx
->
x_requires_grad
)
{
in_grads
->
at
(
0
)
=
JUST
(
functional
::
BroadcastPowXGrad
(
out_grads
.
at
(
0
),
x
,
y
,
z
));
(
*
in_grads
)[
0
]
=
JUST
(
functional
::
BroadcastPowXGrad
(
x
,
y
,
out_grads
[
0
]
));
}
if
(
ctx
->
y_requires_grad
)
{
in_grads
->
at
(
1
)
=
JUST
(
functional
::
BroadcastPowYGrad
(
out_grads
.
at
(
0
),
x
,
y
,
z
));
(
*
in_grads
)[
1
]
=
JUST
(
functional
::
BroadcastPowYGrad
(
x
,
y
,
out_grads
[
0
]
));
}
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -246,9 +245,8 @@ class BroadcastPow : public BroadcastBinaryGrad {
protected:
Maybe
<
void
>
SaveTensorForBackward
(
BroadcastBinaryCaptureState
*
ctx
,
const
TensorTuple
&
inputs
,
const
TensorTuple
&
outputs
)
const
override
{
ctx
->
x_index
=
ctx
->
SaveTensorForBackward
(
inputs
.
at
(
0
));
ctx
->
y_index
=
ctx
->
SaveTensorForBackward
(
inputs
.
at
(
1
));
ctx
->
z_index
=
ctx
->
SaveTensorForBackward
(
outputs
.
at
(
0
));
ctx
->
x_index
=
ctx
->
SaveTensorForBackward
(
inputs
[
0
]);
ctx
->
y_index
=
ctx
->
SaveTensorForBackward
(
inputs
[
1
]);
return
Maybe
<
void
>::
Ok
();
}
};
...
...
@@ -348,5 +346,80 @@ class BroadcastMaximum : public BroadcastMinMax {
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);

class BroadcastFMod : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& out_shape = *(JUST(VectorAt(out_grads, 0))->shape());
    in_grads->resize(2);
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      const auto& x = JUST(VectorAt(ctx->SavedTensors(), ctx->x_index));
      const auto& y = JUST(VectorAt(ctx->SavedTensors(), ctx->y_index));
      auto broad_x_ = x;
      auto broad_y_ = y;
      if (ctx->broadcast_x) {
        const auto& x_shape = *(x->shape());
        const Shape& left_extended_x_shape =
            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
        if (left_extended_x_shape == out_shape) {
          broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
        } else {
          const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
          const std::vector<int32_t> x_axis =
              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
          broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
        }
      }
      if (ctx->broadcast_y) {
        const auto& y_shape = *(y->shape());
        const Shape& left_extended_y_shape =
            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
        if (left_extended_y_shape == out_shape) {
          broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
        } else {
          const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
          const std::vector<int32_t> y_axis =
              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
          broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
        }
      }
      if (ctx->x_requires_grad) {
        if (ctx->broadcast_x) {
          JUST(VectorAt(*in_grads, 0)) =
              JUST(functional::BroadcastReduceSumLike(JUST(VectorAt(out_grads, 0)), x));
        } else {
          JUST(VectorAt(*in_grads, 0)) = JUST(VectorAt(out_grads, 0));
        }
      }
      if (ctx->y_requires_grad) {
        auto result = JUST(functional::TruncDiv(broad_x_, broad_y_));
        result = JUST(functional::Mul(JUST(VectorAt(out_grads, 0)), result));
        JUST(functional::ScalarMul(result, Scalar(-1.f), true));
        if (ctx->broadcast_y) {
          in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(result, y));
        } else {
          in_grads->at(1) = result;
        }
      }
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad && ctx->broadcast_x) {
      ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
    }
    if (ctx->y_requires_grad) {
      ctx->x_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));
      ctx->y_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);

}  // namespace one
}  // namespace oneflow
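The new BroadcastFMod::Apply above encodes the element-wise derivatives of z = fmod(x, y) = x - trunc(x / y) * y: the x-gradient passes the upstream gradient through unchanged, and the y-gradient multiplies it by -trunc(x / y) (the TruncDiv / Mul / ScalarMul(-1) sequence), with BroadcastReduceSumLike handling the broadcast axes. A standalone scalar sanity check of those formulas, assuming only the C++ standard library:

#include <cmath>
#include <cstdio>

int main() {
  const double x = 7.3, y = 2.1, eps = 1e-6;
  // Analytic gradients of z = fmod(x, y) = x - trunc(x / y) * y
  // (valid away from the points where trunc(x / y) jumps).
  const double dz_dx = 1.0;
  const double dz_dy = -std::trunc(x / y);
  // Central finite-difference approximations of the same derivatives.
  const double fd_dx = (std::fmod(x + eps, y) - std::fmod(x - eps, y)) / (2 * eps);
  const double fd_dy = (std::fmod(x, y + eps) - std::fmod(x, y - eps)) / (2 * eps);
  std::printf("dz/dx: analytic=%.6f finite-diff=%.6f\n", dz_dx, fd_dx);
  std::printf("dz/dy: analytic=%.6f finite-diff=%.6f\n", dz_dy, fd_dy);
  return 0;
}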
oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {

struct BroadcastFModCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/consistent_cast.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct CastConsistentCaptureState : public AutoGradCaptureState {
  Symbol<ParallelDesc> parallel_desc;
  Symbol<NdSbp> nd_sbp;
  std::shared_ptr<const Shape> shape;
  Symbol<DType> dtype;
};

class CastToConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    const std::string& op_name = fw_op_expr->op_name();
    grad_op_ = JUST(one::CastFromConsistentOpExpr::New(GradientOpName(op_name)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs,
                      const OpExprInterpContext& interp_ctx) const override {
    ctx->parallel_desc = JUST(interp_ctx.parallel_desc);
    ctx->nd_sbp = JUST(GetDualNdSbp(JUST(interp_ctx.nd_sbp)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    std::shared_ptr<Tensor> out_grad = out_grads.at(0);
    CHECK_OR_RETURN(out_grad->is_consistent())
        << Error::RuntimeError()
        << "Expected global tensor for cast_to_consistent but got local tensor";
    {
      Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
      Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
      out_grad = JUST(functional::ToConsistent(out_grad, parallel_desc_constraint,
                                               *JUST(GetSbpList(nd_sbp_constraint)),
                                               GetNoneSbpList(), /* check_meta */ false));
    }
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cast_to_consistent", CastToConsistent);

class CastFromConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    const std::string& op_name = fw_op_expr->op_name();
    grad_op_ = JUST(one::CastToConsistentOpExpr::New(GradientOpName(op_name)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    const auto& input = inputs.at(0);
    CHECK_OR_RETURN(input->is_consistent())
        << Error::RuntimeError()
        << "Expected global tensor for cast_from_consistent but got local tensor";
    ctx->parallel_desc = JUST(input->parallel_desc());
    ctx->nd_sbp = JUST(input->nd_sbp());
    ctx->shape = input->shape();
    ctx->dtype = input->dtype();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<Shape>("shape", *ctx->shape));
    JUST(attrs.SetAttr<DataType>("dtype", ctx->dtype->data_type()));
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(
        *grad_op_, {out_grads.at(0)}, OpExprInterpContext(attrs, ctx->parallel_desc, dual_nd_sbp)));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cast_from_consistent", CastFromConsistent);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/consistent_to_consistent.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
namespace one {

struct ConsistentToConsistentState : public AutoGradCaptureState {
  Symbol<ParallelDesc> parallel_desc;
  Symbol<NdSbp> nd_sbp;
};

class ConsistentToConsistentGradFunction : public OpExprGradFunction<ConsistentToConsistentState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const ConsistentToConsistentOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    grad_nd_sbp_ = fw_op_expr->grad_nd_sbp();
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ConsistentToConsistentState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs,
                      const OpExprInterpContext& interp_ctx) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->parallel_desc = JUST(inputs.at(0)->parallel_desc());
    ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp());
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ConsistentToConsistentState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& out_grad = out_grads.at(0);
    CHECK_OR_RETURN(out_grad->is_consistent())
        << Error::RuntimeError()
        << "Expected global tensor for consistent_to_consistent but got local tensor";
    in_grads->resize(1);
    const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
    const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
    const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
    (*in_grads)[0] = JUST(one::functional::ToConsistent(out_grad, ctx->parallel_desc,
                                                        *grad_sbp_list, *grad_grad_sbp_list,
                                                        /* check_meta */ false));
    return Maybe<void>::Ok();
  }

 private:
  Optional<Symbol<NdSbp>> grad_nd_sbp_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("consistent_to_consistent", ConsistentToConsistentGradFunction);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/conv.cpp
View file @
a715222c
...
...
@@ -26,6 +26,8 @@ namespace one {
struct ConvolutionNdCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;
  bool weight_requires_grad = false;
  bool has_bias = false;
  bool bias_requires_grad = false;
  size_t input_index;
  size_t weight_index;
...
...
@@ -58,10 +60,17 @@ Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
  CHECK_OR_RETURN(inputs.size() == 2 || inputs.size() == 3);  // NOLINT(maybe-need-error-msg)
  ctx->input_requires_grad = inputs.at(0)->requires_grad();
  ctx->weight_requires_grad = inputs.at(1)->requires_grad();
  if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); }
  if (inputs.size() == 3) {
    ctx->has_bias = true;
    ctx->bias_requires_grad = inputs.at(2)->requires_grad();
  }
  if (!ctx->input_requires_grad && !ctx->weight_requires_grad && !ctx->bias_requires_grad) {
    return Maybe<void>::Ok();
  }
  if (ctx->input_requires_grad) {
    ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1));  // weight
  }
...
@@ -79,7 +88,11 @@ Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorT
Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
                                 TensorTuple* in_grads) const {
  in_grads->resize(2);
  if (ctx->has_bias) {
    in_grads->resize(3);
  } else {
    in_grads->resize(2);
  }
  size_t num_spatial_dims = ctx->kernel_size.size();
  if (ctx->input_requires_grad) {
    const auto& weight = ctx->SavedTensors().at(ctx->weight_index);
...
...
@@ -94,6 +107,18 @@ Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const Ten
        out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
  if (ctx->bias_requires_grad) {
    std::vector<int32_t> dim;
    for (int i = 0; i < out_grads.at(0)->shape()->NumAxes(); ++i) {
      if ((ctx->data_format == "channels_first" && i == 1)
          || (ctx->data_format == "channels_last"
              && i == out_grads.at(0)->shape()->NumAxes() - 1)) {
        continue;
      }
      dim.push_back(i);
    }
    in_grads->at(2) = JUST(functional::ReduceSum(out_grads.at(0), dim, false));
  }
  return Maybe<void>::Ok();
}
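The bias branch added above reduces out_grads.at(0) over every axis except the channel axis (axis 1 for "channels_first", the last axis for "channels_last"), which is the usual convolution bias gradient. A standalone sketch of that reduction for a small contiguous NCHW buffer (illustrative sizes, not OneFlow code):

#include <cstdio>
#include <vector>

int main() {
  const int N = 2, C = 3, H = 2, W = 2;
  // Pretend upstream gradient of ones with shape (N, C, H, W), channels_first layout.
  std::vector<double> out_grad(N * C * H * W, 1.0);
  std::vector<double> bias_grad(C, 0.0);
  // Sum over every axis except the channel axis.
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          bias_grad[c] += out_grad[((n * C + c) * H + h) * W + w];
  // Each channel accumulates N * H * W = 8 contributions here.
  for (int c = 0; c < C; ++c) std::printf("bias_grad[%d] = %.1f\n", c, bias_grad[c]);
  return 0;
}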
...
...
oneflow/core/autograd/gradient_funcs/copy.cpp
View file @
a715222c
...
...
@@ -38,8 +38,14 @@ class Copy : public OpExprGradFunction<CopyCaptureState> {
  Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override {
    ctx->device_type = JUST(inputs.at(0)->device())->type();
    ctx->device_id = JUST(inputs.at(0)->device())->device_id();
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    if (inputs[0]->is_global()) {
      ctx->device_type = JUST(inputs[0]->parallel_desc())->device_tag();
      ctx->device_id = 0;  // global tensor only has one local device
    } else {
      ctx->device_type = JUST(inputs[0]->device())->type();
      ctx->device_id = JUST(inputs[0]->device())->device_id();
    }
    return Maybe<void>::Ok();
  }
...
...
oneflow/core/autograd/gradient_funcs/ctc_loss.cpp
View file @
a715222c
...
...
@@ -57,7 +57,7 @@ Maybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->max_target_length = JUST(composed_attrs.GetAttr<int64_t>("max_target_length"));
  ctx->blank = JUST(composed_attrs.GetAttr<int32_t>("blank"));
  ctx->blank = JUST(composed_attrs.GetAttr<int64_t>("blank"));
  ctx->zero_infinity = JUST(composed_attrs.GetAttr<bool>("zero_infinity"));
  CHECK_EQ_OR_RETURN(inputs.size(), 4);  // NOLINT(maybe-need-error-msg)
...
...