OpenDAS / Oneflow / Commits / a715222c

Commit a715222c authored Feb 28, 2023 by yuguo

0.9.1-rocm

Parent: f262efc9
Changes: 469
Showing 20 changed files with 1777 additions and 56 deletions (+1777 / -56)
oneflow/core/autograd/gradient_funcs/reshape.cpp                                      +29  -2
oneflow/core/autograd/gradient_funcs/rms_norm.cpp                                     +99  -0
oneflow/core/autograd/gradient_funcs/scalar_floordiv.cpp                              +12  -7
oneflow/core/autograd/gradient_funcs/scalar_pow.cpp                                   +0   -4
oneflow/core/autograd/gradient_funcs/scalar_truncdiv.cpp                              +53  -0
oneflow/core/autograd/gradient_funcs/slice.cpp                                        +2   -3
oneflow/core/autograd/gradient_funcs/smooth_l1_loss.cpp                               +14  -13
oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp                 +10  -10
oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy_ms.cpp              +80  -0
oneflow/core/autograd/gradient_funcs/trunc.cpp                                        +60  -0
oneflow/core/autograd/gradient_funcs/unfold.cpp                                       +2   -2
oneflow/core/autograd/gradient_funcs/upsample.cpp                                     +8   -15
oneflow/core/autograd/gradient_funcs/variance.cpp                                     +3   -0
oneflow/core/autograd/gradient_funcs/vector_matrix_product.cpp                        +94  -0
oneflow/core/autograd/higher_order_gradient_funcs/activation.cpp                      +556 -0
oneflow/core/autograd/higher_order_gradient_funcs/avg_pool.cpp                        +158 -0
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_loss.cpp       +114 -0
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits.cpp              +139 -0
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp  +116 -0
oneflow/core/autograd/higher_order_gradient_funcs/conv.cpp                            +228 -0
Too many changes to show: to preserve performance, only 469 of 469+ files are displayed.
oneflow/core/autograd/gradient_funcs/reshape.cpp

@@ -28,7 +28,7 @@ struct ReshapeCaptureState : public AutoGradCaptureState {
   DimVector input_shape_vec;
 };
 
-class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
+class ReshapeGrad : public OpExprGradFunction<ReshapeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -51,7 +51,34 @@ class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
   }
 };
 
-REGISTER_OP_EXPR_GRAD_FUNCTION("reshape", ReshapeOpExprGrad);
+class ReshapeLikeGrad : public OpExprGradFunction<ReshapeCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
+    CHECK_OR_RETURN(!inputs.at(1)->requires_grad())
+        << "ReshapeLikeOp's input[1] need not requires_grad.";
+    ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    in_grads->resize(2);
+    Shape shape(ctx->input_shape_vec);
+    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("reshape", ReshapeGrad);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reshape_like", ReshapeLikeGrad);
 
 }  // namespace one
 }  // namespace oneflow
oneflow/core/autograd/gradient_funcs/rms_norm.cpp  (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct RMSNormCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool weight_requires_grad = false;
  int x_index = -1;
  int inv_rms_index = -1;
  int weight_index = -1;
};

class RMSNormGrad : public OpExprGradFunction<RMSNormCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;

  Maybe<void> Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

Maybe<void> RMSNormGrad::Capture(RMSNormCaptureState* ctx, const TensorTuple& inputs,
                                 const TensorTuple& outputs, const AttrMap& attrs) const {
  // (x, [weight])
  CHECK_GE_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  CHECK_LE_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
  // (y, inv_rms)
  CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)

  // save x
  ctx->x_requires_grad = inputs[0]->requires_grad();
  ctx->x_index = ctx->SaveTensorForBackward(inputs[0]);
  // save weight
  ctx->weight_requires_grad = false;
  if (inputs.size() > 1) {
    ctx->weight_requires_grad = inputs[1]->requires_grad();
    ctx->weight_index = ctx->SaveTensorForBackward(inputs[1]);
  }
  // save inv_rms
  if (ctx->x_requires_grad || ctx->weight_requires_grad) {
    ctx->inv_rms_index = ctx->SaveTensorForBackward(outputs[1]);
  }
  return Maybe<void>::Ok();
}

Maybe<void> RMSNormGrad::Apply(const RMSNormCaptureState* ctx, const TensorTuple& out_grads,
                               TensorTuple* in_grads) const {
  // (x, inv_rms) or (x, weight, inv_rms)
  const auto& saved_tensors = ctx->SavedTensors();
  CHECK_GE_OR_RETURN(saved_tensors.size(), 2);  // NOLINT(maybe-need-error-msg)
  CHECK_LE_OR_RETURN(saved_tensors.size(), 3);  // NOLINT(maybe-need-error-msg)
  // (dy, inv_rms_diff)
  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)

  const auto& dy = out_grads[0];
  const auto& x = saved_tensors.at(ctx->x_index);
  const auto& inv_rms = saved_tensors.at(ctx->inv_rms_index);
  // (x_grad, weight_grad)
  in_grads->resize(2);
  if (ctx->x_requires_grad) {
    if (saved_tensors.size() == 3) {
      const auto& weight = saved_tensors.at(ctx->weight_index);
      in_grads->at(0) = JUST(functional::RMSNormGrad(dy, x, inv_rms, weight, /*param_grad*/ false));
    } else {
      in_grads->at(0) =
          JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ false));
    }
  }
  if (ctx->weight_requires_grad) {
    in_grads->at(1) =
        JUST(functional::RMSNormGrad(dy, x, inv_rms, /*weight*/ NullOpt, /*param_grad*/ true));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("rms_norm", RMSNormGrad);

}  // namespace one
}  // namespace oneflow
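For reference (not part of the commit), the gradient function above corresponds to the usual RMS normalization; the exact epsilon placement and the set of normalized dimensions are fixed by the op itself, so the following is only a sketch:

    y_i = \frac{x_i}{\mathrm{RMS}(x)}\, w_i, \qquad \mathrm{RMS}(x) = \sqrt{\tfrac{1}{n}\sum_j x_j^{2} + \varepsilon}, \qquad \mathrm{inv\_rms} = \mathrm{RMS}(x)^{-1}

This is why Capture saves x, the optional weight w, and the inv_rms output: they are exactly the quantities the backward kernel needs to produce both the input gradient and the weight gradient.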
oneflow/core/autograd/gradient_funcs/scalar_floordiv.cpp

@@ -13,15 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/functional/functional.h"
+#include "oneflow/core/common/container_util.h"
 
 namespace oneflow {
 namespace one {
 
-// FloorDiv derivatives function isn't exists. (author: zhengzekang)
-struct ScalarFloorDivCaptureState : public AutoGradCaptureState {};
+struct ScalarFloorDivCaptureState : public AutoGradCaptureState {
+  bool requires_grad = true;
+};
 
 class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
  public:
@@ -29,17 +31,20 @@ class ScalarFloorDiv : public OpExprGradFunction<ScalarFloorDivCaptureState> {
   Maybe<void> Capture(ScalarFloorDivCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
+    ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
     return Maybe<void>::Ok();
   }
 
   Maybe<void> Apply(const ScalarFloorDivCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
-    UNIMPLEMENTED_THEN_RETURN() << "RuntimeError: derivative for floor_divide is not implemented";
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));
+    }
     return Maybe<void>::Ok();
   }
- private:
-  AttrMap base_attrs_;
 };
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_floordiv", ScalarFloorDiv);
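The switch from UNIMPLEMENTED_THEN_RETURN to a ZerosLike gradient matches the mathematics of the op: scalar floor division is piecewise constant in its input, so its derivative is zero almost everywhere,

    \frac{\partial}{\partial x}\left\lfloor \frac{x}{c} \right\rfloor = 0 \quad \text{for a.e. } x,

and the backward pass can simply return a zero tensor shaped like the incoming gradient. The same reasoning applies to the new scalar_truncdiv and trunc gradient functions further down.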
oneflow/core/autograd/gradient_funcs/scalar_pow.cpp

@@ -55,7 +55,6 @@ class ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {
   Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
-    MutableAttrMap attrs;
     in_grads->resize(1);
     if (ctx->requires_grad) {
       in_grads->at(0) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));
@@ -64,7 +63,6 @@ class ScalarPow : public OpExprGradFunction<ScalarPowCaptureState> {
   }
 
  private:
-  std::shared_ptr<OpExpr> grad_op_;
   AttrMap base_attrs_;
 };
 
@@ -100,7 +98,6 @@ class ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {
   Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors()[0];
-    MutableAttrMap attrs;
     in_grads->resize(1);
     if (ctx->requires_grad) {
       (*in_grads)[0] = JUST(functional::ScalarReversePowGrad(x, out_grads[0], ctx->operand));
@@ -109,7 +106,6 @@ class ScalarReversePow : public OpExprGradFunction<ScalarPowCaptureState> {
   }
 
  private:
-  std::shared_ptr<OpExpr> grad_op_;
   AttrMap base_attrs_;
 };
oneflow/core/autograd/gradient_funcs/scalar_truncdiv.cpp  (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {

struct ScalarTruncDivCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
};

class ScalarTruncDiv : public OpExprGradFunction<ScalarTruncDivCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ScalarTruncDivCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarTruncDivCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      JUST(VectorAt(*in_grads, 0)) = JUST(functional::ZerosLike(JUST(VectorAt(out_grads, 0))));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_truncdiv", ScalarTruncDiv);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/slice.cpp

@@ -98,7 +98,7 @@ class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
     if (ctx->requires_grad_ref) {
       ctx->value_shape = *(inputs[1]->shape());
-      if (inputs[1]->is_consistent()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
+      if (inputs[1]->is_global()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
     }
 
     return Maybe<void>::Ok();
   }
@@ -114,8 +114,7 @@ class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
                                         JUST(out_grads[0]->device())));
     } else {
       const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
-      zeros =
-          JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
-                                              parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
+      zeros = JUST(functional::GlobalConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
+                                              parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
     }
     (*in_grads)[0] = JUST(functional::SliceUpdate(out_grads[0], zeros, ctx->start, ctx->stop,
oneflow/core/autograd/gradient_funcs/smooth_l1_loss.cpp

@@ -22,7 +22,8 @@ namespace oneflow {
 namespace one {
 
 struct SmoothL1LossCaptureState : public AutoGradCaptureState {
-  bool requires_grad = false;
+  bool input_requires_grad = false;
+  bool target_requires_grad = false;
   float beta = 0.0;
 };
 
@@ -37,13 +38,13 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {
   Maybe<void> Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
-    ctx->requires_grad = inputs.at(0)->requires_grad();  // prediction
-    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-
     CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
-    ctx->SaveTensorForBackward(inputs.at(0));  // prediction
-    ctx->SaveTensorForBackward(inputs.at(1));  // label
+    ctx->input_requires_grad = inputs.at(0)->requires_grad();   // input
+    ctx->target_requires_grad = inputs.at(1)->requires_grad();  // target
+    ctx->SaveTensorForBackward(inputs.at(0));  // input
+    ctx->SaveTensorForBackward(inputs.at(1));  // target
 
     ComposedAttrMap composed_attrs(attrs, base_attrs_);
     ctx->beta = JUST(composed_attrs.GetAttr<float>("beta"));
@@ -52,15 +53,15 @@ class SmoothL1Loss : public OpExprGradFunction<SmoothL1LossCaptureState> {
   Maybe<void> Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
-    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
-
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);             // NOLINT(maybe-need-error-msg)
+    CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(), 2);   // NOLINT(maybe-need-error-msg)
     in_grads->resize(2);
-    const auto& prediction = ctx->SavedTensors().at(0);
-    const auto& label = ctx->SavedTensors().at(1);
-    in_grads->at(0) =
-        JUST(functional::SmoothL1LossGrad(out_grads.at(0), prediction, label, ctx->beta));
+    const auto& input = ctx->SavedTensors().at(0);
+    const auto& target = ctx->SavedTensors().at(1);
+    const auto& grad = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));
+    if (ctx->input_requires_grad) { (*in_grads)[0] = grad; }
+    if (ctx->target_requires_grad) { (*in_grads)[1] = JUST(functional::Negative(grad)); }
 
     return Maybe<void>::Ok();
   }
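The new target gradient is just the negated input gradient, which follows from the loss depending on its arguments only through the difference x - t. Writing the elementwise, unreduced loss in its standard form (the reduction factor is omitted here; the kernel handles it),

    \ell(x, t) = \begin{cases} \dfrac{(x - t)^2}{2\beta} & |x - t| < \beta \\ |x - t| - \tfrac{\beta}{2} & \text{otherwise} \end{cases}
    \qquad\Rightarrow\qquad \frac{\partial \ell}{\partial t} = -\frac{\partial \ell}{\partial x},

so SmoothL1LossGrad is computed once and reused, with functional::Negative supplying the target branch.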
oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp

@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/common/container_util.h"
 #include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
@@ -50,10 +51,10 @@ Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureS
                                                const AttrMap& attrs) const {
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
   CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
   CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)
-  ctx->SaveTensorForBackward(outputs.at(0));  // prob
-  ctx->SaveTensorForBackward(inputs.at(1));   // label
+  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // prob
+  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));   // label
   return Maybe<void>::Ok();
 }
@@ -61,15 +62,14 @@ Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCapt
                                              const TensorTuple& out_grads,
                                              TensorTuple* in_grads) const {
   CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)
-  const auto& dy = out_grads.at(1);
-  const auto& prob = ctx->SavedTensors().at(0);
-  const auto& label = ctx->SavedTensors().at(1);
-  MutableAttrMap attrs;
-  JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
+  const auto& dy = JUST(VectorAt(out_grads, 1));
+  const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));
+  const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));
   // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
   // require gradient.
   in_grads->resize(2);
-  in_grads->at(0) = JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));
+  JUST(VectorAt(*in_grads, 0)) =
+      JUST(functional::SparseSoftmaxCrossEntropyGrad(dy, prob, label, ctx->depth));
   return Maybe<void>::Ok();
 }
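This hunk only swaps unchecked .at() indexing for the bounds-checked VectorAt helper; the math behind SparseSoftmaxCrossEntropyGrad is unchanged. For reference, with p = softmax(z) saved as the prob output and an integer class label y, the logit gradient is the standard

    \frac{\partial L}{\partial z_i} = \big(p_i - \mathbf{1}[i = y]\big)\, dy,

which is why Capture stores the prob output and the label rather than the raw prediction.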
oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy_ms.cpp  (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct SparseSoftmaxCrossEntropyMsCaptureState : public AutoGradCaptureState {
  int64_t depth = 0;
};

class SparseSoftmaxCrossEntropyMs
    : public OpExprGradFunction<SparseSoftmaxCrossEntropyMsCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> SparseSoftmaxCrossEntropyMs::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> SparseSoftmaxCrossEntropyMs::Capture(SparseSoftmaxCrossEntropyMsCaptureState* ctx,
                                                 const TensorTuple& inputs,
                                                 const TensorTuple& outputs,
                                                 const AttrMap& attrs) const {
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)
  ctx->SaveTensorForBackward(JUST(VectorAt(outputs, 0)));  // prob
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));   // label
  return Maybe<void>::Ok();
}

Maybe<void> SparseSoftmaxCrossEntropyMs::Apply(const SparseSoftmaxCrossEntropyMsCaptureState* ctx,
                                               const TensorTuple& out_grads,
                                               TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)
  const auto& dy = JUST(VectorAt(out_grads, 1));
  const auto& prob = JUST(VectorAt(ctx->SavedTensors(), 0));
  const auto& label = JUST(VectorAt(ctx->SavedTensors(), 1));
  // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
  // require gradient.
  in_grads->resize(2);
  JUST(VectorAt(*in_grads, 0)) =
      JUST(functional::SparseSoftmaxCrossEntropyMsGrad(dy, prob, label, ctx->depth));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy_ms", SparseSoftmaxCrossEntropyMs);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/trunc.cpp  (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct TruncCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class Trunc : public OpExprGradFunction<TruncCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(TruncCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

Maybe<void> Trunc::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

Maybe<void> Trunc::Capture(TruncCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  return Maybe<void>::Ok();
}

Maybe<void> Trunc::Apply(const TruncCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (ctx->requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("trunc", Trunc);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/unfold.cpp

@@ -73,8 +73,8 @@ Maybe<void> Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_g
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
   in_grads->resize(1);
   in_grads->at(0) =
-      JUST(functional::Fold(out_grads.at(0), ctx->data_format, ctx->output_size, ctx->kernel_size,
-                            ctx->dilation_rate, ctx->padding, ctx->strides));
+      JUST(functional::Fold(out_grads.at(0), ctx->output_size, ctx->kernel_size,
+                            ctx->dilation_rate, ctx->padding, ctx->strides, ctx->data_format));
   return Maybe<void>::Ok();
 }
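The unfold backward reads naturally once unfold (im2col) is viewed as a linear operator U: the vector-Jacobian product is then

    y = U x \;\Rightarrow\; \frac{\partial L}{\partial x} = U^{\top}\,\frac{\partial L}{\partial y},

and U^T is exactly fold (col2im), which accumulates each patch entry back into its source position. The change above only moves data_format from the second argument to the last argument of functional::Fold; the adjoint relationship itself is untouched.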
oneflow/core/autograd/gradient_funcs/upsample.cpp

@@ -100,7 +100,7 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
     ComposedAttrMap composed_attrs(attrs, base_attrs_);
     ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
     ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -112,7 +112,6 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
@@ -151,7 +150,7 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureSt
     ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
     ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
     ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -163,7 +162,6 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureSt
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad(
@@ -200,7 +198,7 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState>
     ComposedAttrMap composed_attrs(attrs, base_attrs_);
     ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
     ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -212,7 +210,6 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState>
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad(
@@ -247,7 +244,7 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     ComposedAttrMap composed_attrs(attrs, base_attrs_);
     ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -259,7 +256,6 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
@@ -298,7 +294,7 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureStat
     ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
     ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
     ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -310,7 +306,6 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureStat
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad(
@@ -348,7 +343,7 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
     ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
     ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
     ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -360,7 +355,6 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
@@ -401,7 +395,7 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
     ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
     ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
     ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
-    if (base_attrs_.find("output_size") != base_attrs_.end()) {
+    if (composed_attrs.Has("output_size")) {
       ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
     }
     ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
@@ -413,7 +407,6 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
-    MutableAttrMap attrs;
     const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
     in_grads->resize(1);
     JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad(
@@ -430,4 +423,4 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCapture
 REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D);
 
 }  // namespace one
-}  // namespace oneflow
\ No newline at end of file
+}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/variance.cpp

@@ -68,6 +68,9 @@ Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_gra
                             TensorTuple* in_grads) const {
   // TODO(): replace it using kernel
   const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+  DataType data_type = x->dtype()->data_type();
+  CHECK_NE_OR_RETURN(data_type, DataType::kBFloat16)
+      << Error::RuntimeError() << "Variance op not support backward for bfloat16 yet!";
   size_t correction = ctx->unbiased ? 1 : 0;
   size_t elem_cnt = 1;
   CHECK_OR_RETURN(ctx->axis.size() > 0)
oneflow/core/autograd/gradient_funcs/vector_matrix_product.cpp  (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {

struct VectorMatrixProductCaptureState : public AutoGradCaptureState {
  bool requires_grad_a = false;
  bool requires_grad_b = false;
  size_t a_index = 0;
  size_t b_index = 1;
};

class VectorMatrixProduct : public OpExprGradFunction<VectorMatrixProductCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(VectorMatrixProductCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const VectorMatrixProductCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 protected:
  AttrMap base_attrs_;
};

Maybe<void> VectorMatrixProduct::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> VectorMatrixProduct::Capture(VectorMatrixProductCaptureState* ctx,
                                         const TensorTuple& inputs, const TensorTuple& outputs,
                                         const AttrMap& attrs) const {
  ctx->requires_grad_a = JUST(VectorAt(inputs, 0))->requires_grad();
  ctx->requires_grad_b = JUST(VectorAt(inputs, 1))->requires_grad();
  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  if (ctx->requires_grad_a) {
    ctx->b_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // input b
  }
  if (ctx->requires_grad_b) {
    ctx->a_index = ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input a
  }
  return Maybe<void>::Ok();
}

Maybe<void> VectorMatrixProduct::Apply(const VectorMatrixProductCaptureState* ctx,
                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {
  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Out grad size should be equal to 1. ";
  in_grads->resize(2);
  if (ctx->requires_grad_a) {
    const auto& input_b = JUST(VectorAt(ctx->SavedTensors(), ctx->b_index));
    JUST(VectorAt(*in_grads, 0)) =
        JUST(functional::VectorMatrixProductGradA(JUST(VectorAt(out_grads, 0)), input_b));
  }
  if (ctx->requires_grad_b) {
    const auto& input_a = JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->a_index));
    JUST(VectorAt(*in_grads, 1)) =
        JUST(functional::VectorMatrixProductGradB(JUST(VectorAt(out_grads, 0)), input_a));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("vector_matrix_product", VectorMatrixProduct);

}  // namespace one
}  // namespace oneflow
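For reference, the two functional kernels used above implement the standard derivatives of a vector-matrix product. Taking a of shape (n), b of shape (n, m) and y = a b of shape (m) (these shapes are an assumption for the sketch; the op definition fixes them), the gradients that GradA and GradB must produce are

    \frac{\partial L}{\partial a} = \frac{\partial L}{\partial y}\, b^{\top}, \qquad \frac{\partial L}{\partial b} = a^{\top}\, \frac{\partial L}{\partial y},

i.e. a row vector times b transposed for the vector input and an outer product for the matrix input, which is why Capture saves b only when a needs a gradient and vice versa.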
oneflow/core/autograd/higher_order_gradient_funcs/activation.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstddef>
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
struct
BaseActivationGradGradCaptureState
:
public
AutoGradCaptureState
{
bool
x_requires_grad
=
false
;
bool
grad_requires_grad
=
false
;
};
typedef
Maybe
<
one
::
Tensor
>
(
*
NoParamActivationBwFunc
)(
const
std
::
shared_ptr
<
one
::
Tensor
>&
,
const
std
::
shared_ptr
<
one
::
Tensor
>&
);
template
<
NoParamActivationBwFunc
BwFunc
,
NoParamActivationBwFunc
BwBwFunc
>
class
NoParamActivationGradGrad
:
public
OpExprGradFunction
<
BaseActivationGradGradCaptureState
>
{
public:
Maybe
<
void
>
Init
(
const
OpExpr
&
op
)
override
{
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
void
>
Capture
(
BaseActivationGradGradCaptureState
*
ctx
,
const
TensorTuple
&
inputs
,
const
TensorTuple
&
outputs
,
const
AttrMap
&
attrs
)
const
override
{
// dy, x
CHECK_EQ_OR_RETURN
(
inputs
.
size
(),
2
);
// NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN
(
outputs
.
size
(),
1
);
// NOLINT(maybe-need-error-msg)
ctx
->
x_requires_grad
=
inputs
.
at
(
1
)
->
requires_grad
();
ctx
->
grad_requires_grad
=
inputs
.
at
(
0
)
->
requires_grad
();
if
(
!
ctx
->
x_requires_grad
&&
!
ctx
->
grad_requires_grad
)
{
return
Maybe
<
void
>::
Ok
();
}
ctx
->
SaveTensorForBackward
(
inputs
.
at
(
1
));
if
(
ctx
->
x_requires_grad
)
{
ctx
->
SaveTensorForBackward
(
inputs
.
at
(
0
));
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
void
>
Apply
(
const
BaseActivationGradGradCaptureState
*
ctx
,
const
TensorTuple
&
out_grads
,
TensorTuple
*
in_grads
)
const
override
{
in_grads
->
resize
(
2
);
const
auto
&
x
=
ctx
->
SavedTensors
().
at
(
0
);
if
(
ctx
->
x_requires_grad
)
{
const
auto
&
grad
=
ctx
->
SavedTensors
().
at
(
1
);
in_grads
->
at
(
1
)
=
JUST
(
functional
::
Mul
(
out_grads
.
at
(
0
),
JUST
(
BwBwFunc
(
x
,
grad
))));
}
if
(
ctx
->
grad_requires_grad
)
{
in_grads
->
at
(
0
)
=
JUST
(
BwFunc
(
out_grads
.
at
(
0
),
x
));
}
return
Maybe
<
void
>::
Ok
();
}
};
#define INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS(op_type_name, op_cls) \
class op_cls##GradGradCls final \
: public NoParamActivationGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> { \
}; \
REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);
// first order backward param: (dy, x)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"mish_grad"
,
Mish
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"gelu_grad"
,
Gelu
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"silu_grad"
,
Silu
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"selu_grad"
,
Selu
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"softsign_grad"
,
SoftSign
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"hardsigmoid_grad"
,
HardSigmoid
)
INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
(
"hardswish_grad"
,
HardSwish
)
#undef INSTANTIAT_AND_REGISTER_NOPARAM_ACTIVATION_CLASS
struct HardShrinkGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool grad_requires_grad = false;
  double lambd = 0.5;
};

class HardShrinkGradGrad : public OpExprGradFunction<HardShrinkGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(HardShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // y, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->y_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const HardShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(1) = JUST(functional::HardShrinkGrad(y, out_grads.at(0), ctx->lambd));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct SoftShrinkGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool grad_requires_grad = false;
  double alpha = 0.5;
};

class SoftShrinkGradGrad : public OpExprGradFunction<SoftShrinkGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SoftShrinkGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // y, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->y_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SoftShrinkGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(1) = JUST(functional::SoftShrinkGrad(y, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
struct ReluGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool grad_requires_grad = false;
};

class ReluGradGrad : public OpExprGradFunction<ReluGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ReluGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, y
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->y_requires_grad) { in_grads->at(1) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
    }
    return Maybe<void>::Ok();
  }
};
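An informal note on why ReluGradGrad needs no dedicated kernel: the first-order backward is elementwise and linear in dy,

\[ dx = dy \cdot \mathbb{1}[y > 0], \qquad \frac{\partial (dx)}{\partial (dy)} = \mathbb{1}[y > 0], \qquad \frac{\partial (dx)}{\partial y} = 0 \ \text{(almost everywhere)}, \]

so the dy branch simply reuses functional::ReluGrad on the incoming gradient and the y branch returns zeros.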
struct LeakyReluGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
  float alpha = 0.01;
};

class LeakyReluGradGrad : public OpExprGradFunction<LeakyReluGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(LeakyReluGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const LeakyReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(1) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct SoftplusGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
  double beta = 1.0;
  double threshold = 20.0;
};

class SoftplusGradGrad : public OpExprGradFunction<SoftplusGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SoftplusGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->beta = JUST(composed_attrs.GetAttr<double>("beta"));
    ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold"));
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SoftplusGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& x = ctx->SavedTensors().at(0);
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      in_grads->at(0) = JUST(functional::Mul(
          out_grads.at(0), JUST(functional::SoftplusGradGrad(x, grad, ctx->beta, ctx->threshold))));
    }
    if (ctx->grad_requires_grad) {
      in_grads->at(1) =
          JUST(functional::SoftplusGrad(x, out_grads.at(0), ctx->beta, ctx->threshold));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct HardTanhGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;
  bool grad_requires_grad = false;
  double min_val = -1.0;
  double max_val = 1.0;
};

class HardTanhGradGrad : public OpExprGradFunction<HardTanhGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(HardTanhGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // y, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->y_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->y_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->min_val = JUST(composed_attrs.GetAttr<double>("min_val"));
    ctx->max_val = JUST(composed_attrs.GetAttr<double>("max_val"));
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const HardTanhGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->y_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(1) =
          JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct EluGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
  double alpha = 1.0;
};

class EluGradGrad : public OpExprGradFunction<EluGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(EluGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& x = ctx->SavedTensors().at(0);
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      in_grads->at(0) = JUST(
          functional::Mul(out_grads.at(0), JUST(functional::EluGradGrad(x, grad, ctx->alpha))));
    }
    if (ctx->grad_requires_grad) {
      in_grads->at(1) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

class CeluGradGrad : public EluGradGrad {
 public:
  Maybe<void> Apply(const EluGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& x = ctx->SavedTensors().at(0);
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      in_grads->at(0) = JUST(functional::CeluGradGrad(
          x, JUST(functional::Mul(out_grads.at(0), (grad))), ctx->alpha));
    }
    if (ctx->grad_requires_grad) {
      in_grads->at(1) = JUST(functional::CeluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }
};

struct PReluGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;
  bool input_requires_grad = false;
  bool alpha_requires_grad = false;
  size_t grad_index = 0;
  size_t input_index = 1;
  size_t alpha_index = 2;
};

class PReluGradGrad : public OpExprGradFunction<PReluGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(PReluGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, x, alpha
    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();   // grad
    ctx->input_requires_grad = inputs.at(1)->requires_grad();  // input
    ctx->alpha_requires_grad = inputs.at(2)->requires_grad();  // alpha
    ctx->input_index = ctx->SaveTensorForBackward(inputs.at(1));
    ctx->alpha_index = ctx->SaveTensorForBackward(inputs.at(2));
    ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const PReluGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& input = ctx->SavedTensors().at(ctx->input_index);
    const auto& alpha = ctx->SavedTensors().at(ctx->alpha_index);
    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
    const auto& grad_for_input = out_grads.at(0);
    const auto& grad_for_alpha = out_grads.at(1);
    const auto& condition = JUST(functional::ScalarLogicalLess(input, Scalar(0.0)));
    const auto& zero_grad = JUST(functional::ZerosLike(alpha));
    // alpha can broadcast to input
    if (ctx->grad_requires_grad) {
      auto input_mul_grad = JUST(functional::Mul(alpha, grad_for_input));
      auto alpha_mul_grad = JUST(functional::Mul(input, grad_for_alpha));
      auto result = JUST(functional::Add(input_mul_grad, alpha_mul_grad, /*alpha=*/Scalar(1.0),
                                         /*inplace*/ false));
      in_grads->at(0) = JUST(functional::Where(condition, result, grad_for_input));
    }
    if (ctx->input_requires_grad) {
      auto result = JUST(functional::Mul(grad, grad_for_alpha));
      in_grads->at(1) = JUST(functional::Where(condition, result, zero_grad));
    }
    if (ctx->alpha_requires_grad) {
      auto result = JUST(functional::Mul(grad, grad_for_input));
      in_grads->at(2) = JUST(functional::Where(condition, result, zero_grad));
    }
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;
};
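A sketch of the algebra behind the three Where branches above (elementwise, before any broadcast reduction of alpha, which is handled outside this function): the first-order backward of PReLU is

\[ dx = dy\,\big(\alpha\,\mathbb{1}[x<0] + \mathbb{1}[x\ge 0]\big), \qquad d\alpha = dy\,x\,\mathbb{1}[x<0]. \]

Differentiating with respect to dy, x and alpha and contracting with the incoming gradients \(\overline{dx}\) (out_grads[0]) and \(\overline{d\alpha}\) (out_grads[1]) gives

\[ \overline{dy} = \mathbb{1}[x<0]\,\big(\alpha\,\overline{dx} + x\,\overline{d\alpha}\big) + \mathbb{1}[x\ge 0]\,\overline{dx}, \qquad \overline{x} = \mathbb{1}[x<0]\,dy\,\overline{d\alpha}, \qquad \overline{\alpha} = \mathbb{1}[x<0]\,dy\,\overline{dx}, \]

which is exactly the selection performed with Where(condition, ...) above.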
struct ThresholdGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
  double threshold = 0.0;
};

class ThresholdGradGrad : public OpExprGradFunction<ThresholdGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ThresholdGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, dy
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->x_requires_grad && !ctx->grad_requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold_val"));
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ThresholdGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) { in_grads->at(0) = JUST(functional::ZerosLike(out_grads.at(0))); }
    if (ctx->grad_requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(1) = JUST(functional::ThresholdGrad(x, out_grads.at(0), ctx->threshold));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("relu_grad", ReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("elu_grad", EluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("celu_grad", CeluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("prelu_grad", PReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink_grad", HardShrinkGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink_grad", SoftShrinkGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu_grad", LeakyReluGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh_grad", HardTanhGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold_grad", ThresholdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus_grad", SoftplusGradGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/avg_pool.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {

struct AdaptiveAvgPoolNDGradGradCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;
  bool grad_requires_grad = false;
  std::vector<int64_t> pool_output_size;
};

template<int ndims>
class AdaptiveAvgPoolNdNdGradGrad
    : public OpExprGradFunction<AdaptiveAvgPoolNDGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, x
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs[0]->requires_grad();
    ctx->input_requires_grad = inputs[1]->requires_grad();
    if (ctx->grad_requires_grad) {
      const auto& grad_shape = *inputs[0]->shape();
      if (ndims == 1) {
        ctx->pool_output_size = {grad_shape[grad_shape.size() - 1]};
      } else if (ndims == 2) {
        ctx->pool_output_size = {grad_shape[grad_shape.size() - 2],
                                 grad_shape[grad_shape.size() - 1]};
      } else if (ndims == 3) {
        ctx->pool_output_size = {grad_shape[grad_shape.size() - 3],
                                 grad_shape[grad_shape.size() - 2],
                                 grad_shape[grad_shape.size() - 1]};
      } else {
        UNIMPLEMENTED_THEN_RETURN();
      }
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AdaptiveAvgPoolNDGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    if (ctx->grad_requires_grad) {
      if (ndims == 1) {
        (*in_grads)[0] = JUST(functional::AdaptiveAvgPool1D(out_grads[0], ctx->pool_output_size));
      } else if (ndims == 2) {
        (*in_grads)[0] = JUST(functional::AdaptiveAvgPool2D(out_grads[0], ctx->pool_output_size));
      } else if (ndims == 3) {
        (*in_grads)[0] = JUST(functional::AdaptiveAvgPool3D(out_grads[0], ctx->pool_output_size));
      } else {
        UNIMPLEMENTED_THEN_RETURN();
      }
    }
    if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
    return Maybe<void>::Ok();
  }
};

struct AvgPoolGradGradCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;
  bool grad_requires_grad = false;
  std::string data_format;
  std::vector<int32_t> padding;
  std::vector<int32_t> kernel_size;
  std::vector<int32_t> stride;
  bool ceil_mode = false;
  bool count_include_pad = false;
  int32_t divisor_override = 0;
};

class AvgPoolNdGradGrad : public OpExprGradFunction<AvgPoolGradGradCaptureState> {
 public:
  virtual ~AvgPoolNdGradGrad() = default;

  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(AvgPoolGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, x
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs[0]->requires_grad();
    ctx->input_requires_grad = inputs[1]->requires_grad();
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
    ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
    ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
    ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
    ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>("count_include_pad"));
    ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>("divisor_override"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AvgPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    if (ctx->grad_requires_grad) {
      int32_t ndims = ctx->kernel_size.size();
      const auto pool_op =
          (ndims == 1 ? functional::AvgPool1D
                      : (ndims == 2 ? functional::AvgPool2D
                                    : (ndims == 3 ? functional::AvgPool3D : nullptr)));
      CHECK_NOTNULL_OR_RETURN(pool_op);  // NOLINT(maybe-need-error-msg)
      (*in_grads)[0] = JUST(pool_op(out_grads[0], ctx->kernel_size, ctx->stride, ctx->padding,
                                    ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override,
                                    ctx->data_format));
    }
    if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
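Both grad-grad classes above rely on the same observation: average pooling (adaptive or not) is a linear map P, so its first-order backward dx = Pᵀ dy is itself linear and the second-order pass needs no saved activations. Informally,

\[ dx = P^{\top} dy \;\Rightarrow\; \overline{dy} = P\,\overline{dx}, \qquad \overline{x} = 0, \]

which is why the dy branch re-applies the forward pooling to out_grads[0] and the x branch returns zeros.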
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d_grad", AvgPoolNdGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d_grad", AdaptiveAvgPoolNdNdGradGrad<1>);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d_grad", AdaptiveAvgPoolNdNdGradGrad<2>);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d_grad", AdaptiveAvgPoolNdNdGradGrad<3>);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_loss.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;
  bool input_requires_grad = false;
  bool target_requires_grad = false;
  bool has_weight = false;
};

class BinaryCrossEntropyGradGrad
    : public OpExprGradFunction<BinaryCrossEntropyGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyGradGradCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;
};

Maybe<void> BinaryCrossEntropyGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }

Maybe<void> BinaryCrossEntropyGradGrad::Capture(BinaryCrossEntropyGradGradCaptureState* ctx,
                                                const TensorTuple& inputs,
                                                const TensorTuple& outputs,
                                                const AttrMap& attrs) const {
  // dy, input, target[, weight]
  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->target_requires_grad = inputs[2]->requires_grad();
  ctx->has_weight = inputs.size() == 4;

  ctx->SaveTensorForBackward(inputs[0]);  // grad
  ctx->SaveTensorForBackward(inputs[1]);  // input
  ctx->SaveTensorForBackward(inputs[2]);  // target
  if (ctx->has_weight) {
    ctx->SaveTensorForBackward(inputs[3]);  // weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyGradGrad::Apply(const BinaryCrossEntropyGradGradCaptureState* ctx,
                                              const TensorTuple& out_grads,
                                              TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
                     3 + ctx->has_weight);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(3 + ctx->has_weight);
  const auto& grad = ctx->SavedTensors()[0];
  const auto& input = ctx->SavedTensors()[1];
  const auto& target = ctx->SavedTensors()[2];

  // dx = grad * [-target/input + (1-target)/(1-input)]
  // grad_for_grad = out_grad * [-target/input + (1-target)/(1-input)]
  // grad_for_input = out_grad * grad * [target/(input*input) + (1-target)/((1-input)*(1-input))]
  //                = out_grad * grad * [(input*input-2*input*target+target)/(input*(1-input))^2]
  // grad_for_target = out_grad * grad * [1/(input*(1-input))]
  if (ctx->grad_requires_grad) {
    const auto& weight =
        ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;
    (*in_grads)[0] =
        JUST(functional::BinaryCrossEntropyLossGrad(out_grads[0], input, target, weight));
  }
  if (ctx->input_requires_grad) {
    auto one_sub_input = JUST(functional::ScalarSub(1, input, /*alpha=*/1));
    auto input_mul_target = JUST(functional::Mul(input, target));
    auto numerator = JUST(functional::sequence_function(functional::Square)
                              .then(std::bind(functional::Sub, std::placeholders::_1,
                                              input_mul_target, /*alpha=*/2, /*inplace=*/false))
                              .then([&target](const std::shared_ptr<Tensor>& in) {
                                return functional::Add(in, target, /*alpha=*/1, /*inplace=*/false);
                              })
                              .call(input));
    auto res = JUST(functional::sequence_function(functional::Mul)
                        .then(functional::Square)
                        .then(std::bind(functional::Div, numerator, std::placeholders::_1))
                        .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                        .then(std::bind(functional::Mul, std::placeholders::_1, grad))
                        .call(input, one_sub_input));
    (*in_grads)[1] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;
  }
  if (ctx->target_requires_grad) {
    auto input_sub_one = JUST(functional::ScalarAdd(-1, input, /*alpha=*/1));
    auto res = JUST(functional::sequence_function(functional::Mul)
                        .then(std::bind(functional::LogGrad, std::placeholders::_1, out_grads[0]))
                        .then(std::bind(functional::Mul, std::placeholders::_1, grad))
                        .call(input, input_sub_one));
    (*in_grads)[2] = ctx->has_weight ? JUST(functional::Mul(ctx->SavedTensors()[3], res)) : res;
  }
  return Maybe<void>::Ok();
}
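A sketch of the algebra behind the input branch (same symbols as the comments: x = input, t = target, g = the saved first-order grad; the optional weight just multiplies the result at the end):

\[ dx = g\Big(\frac{1-t}{1-x} - \frac{t}{x}\Big), \qquad \frac{\partial (dx)}{\partial x} = g\Big(\frac{t}{x^{2}} + \frac{1-t}{(1-x)^{2}}\Big) = g\,\frac{x^{2} - 2xt + t}{\big(x(1-x)\big)^{2}}, \]

which is the numerator/denominator pair assembled with Square, Sub, Div and Mul in the sequence_function chain above.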
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_grad", BinaryCrossEntropyGradGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;
  bool input_requires_grad = false;
  bool target_requires_grad = false;
  bool has_weight = false;
  bool has_pos_weight = false;
};

class BinaryCrossEntropyWithLogitsGradGrad
    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,
                      const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Capture(
    BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& inputs,
    const TensorTuple& outputs, const AttrMap& attrs) const {
  // dy, input, target[, weight][, pos_weight]
  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 5);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->target_requires_grad = inputs[2]->requires_grad();
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
  ctx->has_weight = inputs.size() == 5 || (inputs.size() == 4 && !ctx->has_pos_weight);

  ctx->SaveTensorForBackward(inputs[0]);  // grad
  ctx->SaveTensorForBackward(inputs[1]);  // input
  ctx->SaveTensorForBackward(inputs[2]);  // target
  if (inputs.size() == 4) {
    ctx->SaveTensorForBackward(inputs[3]);  // weight or pos_weight
  }
  if (inputs.size() == 5) {
    ctx->SaveTensorForBackward(inputs[3]);  // weight
    ctx->SaveTensorForBackward(inputs[4]);  // pos_weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsGradGrad::Apply(
    const BinaryCrossEntropyWithLogitsGradGradCaptureState* ctx, const TensorTuple& out_grads,
    TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(ctx->SavedTensors().size(),
                     3 + ctx->has_weight + ctx->has_pos_weight);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(3 + ctx->has_weight + ctx->has_pos_weight);
  const auto& grad = ctx->SavedTensors()[0];
  const auto& input = ctx->SavedTensors()[1];
  const auto& target = ctx->SavedTensors()[2];
  const size_t pos_weight_index = ctx->has_weight ? 4 : 3;
  const auto& weight = ctx->has_weight ? Optional<one::Tensor>(ctx->SavedTensors()[3]) : NullOpt;
  const auto& pos_weight =
      ctx->has_pos_weight ? Optional<one::Tensor>(ctx->SavedTensors()[pos_weight_index]) : NullOpt;

  // dx = grad * weight * (-target*(1-input.sigmoid())*pos_weight + input.sigmoid()*(1-target))
  // grad_for_input = out_grad * grad * weight * sig * (1-sig) * [pos_weight * target + 1 - target]
  // grad_for_target = -out_grad * grad * weight * [pos_weight + sig - pos_weight * sig]
  if (ctx->grad_requires_grad) {
    (*in_grads)[0] = JUST(functional::BinaryCrossEntropyWithLogitsLossGrad(
        out_grads[0], input, target, weight, pos_weight));
  }
  if (ctx->input_requires_grad) {
    auto res = JUST(functional::sequence_function(functional::Sigmoid)
                        .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, grad))
                        .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                        .call(input));
    if (ctx->has_pos_weight) {
      res = JUST(functional::sequence_function(functional::Mul)
                     .then([](const std::shared_ptr<Tensor>& input) {
                       return functional::ScalarAdd(1, input, /*alpha=*/Scalar(1));
                     })
                     .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,
                                     /*inplace=*/false))
                     .then(std::bind(functional::Mul, std::placeholders::_1, res))
                     .call(JUST(pos_weight), target));
    }
    if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }
    (*in_grads)[1] = res;
  }
  if (ctx->target_requires_grad) {
    auto res = JUST(functional::sequence_function(functional::Mul)
                        .then(functional::Negative)
                        .call(out_grads[0], grad));
    if (ctx->has_pos_weight) {
      auto sig = JUST(functional::Sigmoid(input));
      auto one_sub_sig = JUST(functional::ScalarSub(1, sig, /*alpha=*/1));
      res = JUST(functional::sequence_function(functional::Mul)
                     .then([&sig](const std::shared_ptr<Tensor>& input) {
                       return functional::Add(input, sig, /*alpha=*/Scalar(1), /*inplace=*/false);
                     })
                     .then(std::bind(functional::Mul, std::placeholders::_1, res))
                     .call(one_sub_sig, JUST(pos_weight)));
    }
    if (ctx->has_weight) { res = JUST(functional::Mul(res, JUST(weight))); }
    (*in_grads)[2] = res;
  }
  return Maybe<void>::Ok();
}
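The comment formulas written out, with σ = sigmoid(input), w the optional weight and p the optional pos_weight (both effectively 1 when absent):

\[ dx = g\,w\,\big(\sigma(1-t) - t\,(1-\sigma)\,p\big), \qquad \frac{\partial (dx)}{\partial x} = g\,w\,\sigma(1-\sigma)\,\big(p\,t + 1 - t\big), \qquad \frac{\partial (dx)}{\partial t} = -\,g\,w\,\big(p + \sigma - p\,\sigma\big). \]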
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_grad",
                               BinaryCrossEntropyWithLogitsGradGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;
  bool input_requires_grad = false;
  bool target_requires_grad = false;
  size_t grad_index = 0;
  size_t input_index = 0;
  size_t target_index = 0;
};

class BinaryCrossEntropyWithLogitsReduceMeanGradGrad
    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
                      const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;
};

Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Init(const OpExpr& op) {
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Capture(
    BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx, const TensorTuple& inputs,
    const TensorTuple& outputs, const AttrMap& attrs) const {
  // dy, input, target
  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->target_requires_grad = inputs[2]->requires_grad();

  if (ctx->input_requires_grad || ctx->target_requires_grad) {
    ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]);  // grad
  }
  if (ctx->input_requires_grad || ctx->grad_requires_grad) {
    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);  // input
  }
  if (ctx->grad_requires_grad) {
    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);  // target
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMeanGradGrad::Apply(
    const BinaryCrossEntropyWithLogitsReduceMeanGradGradCaptureState* ctx,
    const TensorTuple& out_grads, TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(3);
  // dx = grad * weight * (input.sigmoid() - target)
  // grad_for_input = out_grad * grad * weight * sig * (1-sig)
  // grad_for_target = -out_grad * grad * weight
  if (ctx->grad_requires_grad) {
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
    (*in_grads)[0] =
        JUST(functional::sequence_function(functional::Sigmoid)
                 .then(std::bind(functional::Sub, std::placeholders::_1, target, /*alpha=*/1,
                                 /*inplace=*/false))
                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::ReduceMean, std::placeholders::_1,
                                 std::vector<int32_t>{}, /*keepdim=*/false))
                 .call(input));
  }
  if (ctx->input_requires_grad) {
    const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));
    (*in_grads)[1] =
        JUST(functional::sequence_function(functional::Sigmoid)
                 .then(std::bind(functional::SigmoidGrad, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::Mul, std::placeholders::_1, mean_grad))
                 .call(input));
  }
  if (ctx->target_requires_grad) {
    const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
    const auto& mean_grad = JUST(functional::ScalarMul(1.0 / out_grads[0]->nelement(), grad));
    (*in_grads)[2] = JUST(functional::sequence_function(functional::Mul)
                              .then(functional::Negative)
                              .call(out_grads[0], mean_grad));
  }
  return Maybe<void>::Ok();
}
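An informal summary of the three branches, with N the element count, σ_i = sigmoid(x_i) and g the saved first-order grad (this fused op takes no weight input, so the weight in the comment is effectively 1):

\[ dx_i = \frac{g}{N}\,(\sigma_i - t_i) \;\Rightarrow\; \overline{g} = \frac{1}{N}\sum_i \overline{dx}_i\,(\sigma_i - t_i), \qquad \overline{x}_i = \frac{g}{N}\,\sigma_i(1-\sigma_i)\,\overline{dx}_i, \qquad \overline{t}_i = -\frac{g}{N}\,\overline{dx}_i, \]

where the sum over i divided by N is exactly the ReduceMean call in the grad branch.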
REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean_grad",
                               BinaryCrossEntropyWithLogitsReduceMeanGradGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/conv.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace oneflow {
namespace one {

struct ConvDataGradGradCaptureState : public AutoGradCaptureState {
  bool w_requires_grad = false;
  bool grad_requires_grad = false;
  size_t w_index = 0;
  size_t grad_index = 0;
  std::string data_format;
  std::vector<int32_t> padding_before;
  std::vector<int32_t> kernel_size;
  std::vector<int32_t> strides;
  std::vector<int32_t> dilation_rate;
  int32_t groups = 0;
};

class ConvDataGradGrad : public OpExprGradFunction<ConvDataGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ConvDataGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ConvDataGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ConvDataGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ConvDataGradGrad::Capture(ConvDataGradGradCaptureState* ctx,
                                      const TensorTuple& inputs, const TensorTuple& outputs,
                                      const AttrMap& attrs) const {
  // input: dy, w, x_like, [add to output]
  // output: dx
  CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->w_requires_grad = inputs.at(1)->requires_grad();
  ctx->grad_requires_grad = inputs.at(0)->requires_grad();
  if (ctx->grad_requires_grad) { ctx->w_index = ctx->SaveTensorForBackward(inputs.at(1)); }
  if (ctx->w_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
  return Maybe<void>::Ok();
}

Maybe<void> ConvDataGradGrad::Apply(const ConvDataGradGradCaptureState* ctx,
                                    const TensorTuple& out_grads, TensorTuple* in_grads) const {
  in_grads->resize(3);
  size_t num_spatial_dims = ctx->kernel_size.size();
  // first order forward: ConvND
  //   x * w = y  ( * => convolution)
  // first order backward:
  //   x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
  //   w_grad = x * y_grad        (x.shape * y.shape -> w.shape) call ConvFilterGrad
  // second order forward (first order backward): ConvDataGrad
  //   y_grad * w.rot180 = x_grad
  // second order backward:
  //   w_grad_grad = out_grads_x * y_grad (x.shape * y.shape -> w.shape) call ConvFilterGrad
  //   grad_for_y_grad = out_grads_x * w  (x.shape * w.shape -> y.shape) call ConvND

  // w_grad_grad
  if (ctx->w_requires_grad) {
    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
    in_grads->at(1) = JUST(functional::ConvFilterGrad(
        grad, out_grads.at(0), num_spatial_dims, ctx->kernel_size, ctx->strides,
        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
  // grad_for_y_grad
  if (ctx->grad_requires_grad) {
    const auto& w = ctx->SavedTensors().at(ctx->w_index);
    const int32_t ndims = ctx->kernel_size.size();
    const auto conv_op = (ndims == 1 ? functional::Conv1d
                                     : (ndims == 2 ? functional::Conv2d
                                                   : (ndims == 3 ? functional::Conv3d : nullptr)));
    CHECK_NOTNULL_OR_RETURN(conv_op);  // NOLINT(maybe-need-error-msg)
    in_grads->at(0) =
        JUST(conv_op(out_grads.at(0), w, Optional<Tensor>(), ctx->strides, ctx->padding_before,
                     ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
  return Maybe<void>::Ok();
}
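Because y = x * w is bilinear (with * the correlation computed by ConvND), the first-order backward dx = ConvDataGrad(dy, w) is linear in both dy and w; contracting its derivatives with the incoming gradient \(\overline{dx}\) reproduces the two branches above:

\[ \overline{w} = \mathrm{ConvFilterGrad}(dy, \overline{dx}), \qquad \overline{dy} = \mathrm{ConvND}(\overline{dx}, w). \]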
struct ConvFilterGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;
  bool grad_requires_grad = false;
  size_t x_index = 0;
  size_t grad_index = 0;
  std::string data_format;
  std::vector<int32_t> padding_before;
  std::vector<int32_t> kernel_size;
  std::vector<int32_t> strides;
  std::vector<int32_t> dilation_rate;
  int32_t groups = 0;
};

class ConvFilterGradGrad : public OpExprGradFunction<ConvFilterGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(ConvFilterGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const ConvFilterGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> ConvFilterGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> ConvFilterGradGrad::Capture(ConvFilterGradGradCaptureState* ctx,
                                        const TensorTuple& inputs, const TensorTuple& outputs,
                                        const AttrMap& attrs) const {
  // input: dy, x
  // output: dw
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->x_requires_grad = inputs.at(1)->requires_grad();
  ctx->grad_requires_grad = inputs.at(0)->requires_grad();
  ctx->x_index = ctx->SaveTensorForBackward(inputs.at(1));
  if (ctx->x_requires_grad) { ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0)); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
  return Maybe<void>::Ok();
}

Maybe<void> ConvFilterGradGrad::Apply(const ConvFilterGradGradCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
  in_grads->resize(2);
  size_t num_spatial_dims = ctx->kernel_size.size();
  // first order forward: ConvND
  //   x * w = y  ( * => convolution)
  // first order backward:
  //   x_grad = y_grad * w.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
  //   w_grad = x * y_grad        (x.shape * y.shape -> w.shape) call ConvFilterGrad
  // second order forward (first order backward): ConvFilterGrad
  //   x * y_grad = w_grad
  // second order backward:
  //   x_grad_grad = out_grads_w * y_grad.rot180 (y.shape * w.shape -> x.shape) call ConvDataGrad
  //   grad_for_y_grad = x * out_grads_w         (x.shape * w.shape -> y.shape) call ConvND

  // x_grad_grad
  if (ctx->x_requires_grad) {
    const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    in_grads->at(1) = JUST(functional::ConvDataGrad(
        grad, out_grads.at(0), JUST(x->detach()), num_spatial_dims, ctx->kernel_size, ctx->strides,
        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
  // grad_for_y_grad
  if (ctx->grad_requires_grad) {
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    const int32_t ndims = ctx->kernel_size.size();
    const auto conv_op = (ndims == 1 ? functional::Conv1d
                                     : (ndims == 2 ? functional::Conv2d
                                                   : (ndims == 3 ? functional::Conv3d : nullptr)));
    CHECK_NOTNULL_OR_RETURN(conv_op);  // NOLINT(maybe-need-error-msg)
    in_grads->at(0) =
        JUST(conv_op(x, out_grads.at(0), Optional<Tensor>(), ctx->strides, ctx->padding_before,
                     ctx->dilation_rate, ctx->groups, ctx->data_format));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("conv_data_grad", ConvDataGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("conv_filter_grad", ConvFilterGradGrad);

}  // namespace one
}  // namespace oneflow