Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1347 additions
and
61 deletions
+1347
-61
oneflow/core/autograd/higher_order_gradient_funcs/div.cpp
oneflow/core/autograd/higher_order_gradient_funcs/div.cpp
+102
-0
oneflow/core/autograd/higher_order_gradient_funcs/kl_div_loss.cpp
...core/autograd/higher_order_gradient_funcs/kl_div_loss.cpp
+89
-0
oneflow/core/autograd/higher_order_gradient_funcs/log_softmax.cpp
...core/autograd/higher_order_gradient_funcs/log_softmax.cpp
+87
-0
oneflow/core/autograd/higher_order_gradient_funcs/math_unary_op.cpp
...re/autograd/higher_order_gradient_funcs/math_unary_op.cpp
+133
-0
oneflow/core/autograd/higher_order_gradient_funcs/matmul.cpp
oneflow/core/autograd/higher_order_gradient_funcs/matmul.cpp
+85
-0
oneflow/core/autograd/higher_order_gradient_funcs/max_pool.cpp
...ow/core/autograd/higher_order_gradient_funcs/max_pool.cpp
+68
-0
oneflow/core/autograd/higher_order_gradient_funcs/nll_loss.cpp
...ow/core/autograd/higher_order_gradient_funcs/nll_loss.cpp
+93
-0
oneflow/core/autograd/higher_order_gradient_funcs/pow.cpp
oneflow/core/autograd/higher_order_gradient_funcs/pow.cpp
+179
-0
oneflow/core/autograd/higher_order_gradient_funcs/scalar_pow.cpp
.../core/autograd/higher_order_gradient_funcs/scalar_pow.cpp
+153
-0
oneflow/core/autograd/higher_order_gradient_funcs/slice.cpp
oneflow/core/autograd/higher_order_gradient_funcs/slice.cpp
+67
-0
oneflow/core/autograd/higher_order_gradient_funcs/smooth_l1_loss.cpp
...e/autograd/higher_order_gradient_funcs/smooth_l1_loss.cpp
+103
-0
oneflow/core/autograd/higher_order_gradient_funcs/softmax.cpp
...low/core/autograd/higher_order_gradient_funcs/softmax.cpp
+87
-0
oneflow/core/boxing/asymmetric_broadcast.cpp
oneflow/core/boxing/asymmetric_broadcast.cpp
+12
-14
oneflow/core/boxing/boxing_interpreter_status.h
oneflow/core/boxing/boxing_interpreter_status.h
+6
-8
oneflow/core/boxing/ccl_boxing_function.cpp
oneflow/core/boxing/ccl_boxing_function.cpp
+52
-10
oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp
oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp
+4
-10
oneflow/core/boxing/eager_boxing_interpreter.cpp
oneflow/core/boxing/eager_boxing_interpreter.cpp
+1
-1
oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp
oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp
+8
-1
oneflow/core/boxing/flatten_hierarchy.cpp
oneflow/core/boxing/flatten_hierarchy.cpp
+3
-2
oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp
oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp
+15
-15
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/autograd/higher_order_gradient_funcs/div.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <functional>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of "broadcast_div_grad".
// Records which of the three forward inputs (dz, z, y) need gradients and
// where each saved tensor lives in the SavedTensors() vector.
struct DivGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;     // divisor y needs a gradient
  bool z_requires_grad = false;     // quotient z (= x/y) needs a gradient
  bool grad_requires_grad = false;  // incoming gradient dz needs a gradient
  // Indices into SavedTensors(); defaults are only placeholders — the real
  // values are assigned in Capture() depending on which tensors get saved.
  size_t y_index = 0;
  size_t z_index = 1;
  size_t grad_index = 2;
};
// Double-backward (gradient-of-gradient) for the broadcast_div_grad op.
// Forward op inputs are (dz, z, y); its output is the first-order gradient.
class DivGradGrad : public OpExprGradFunction<DivGradGradCaptureState> {
  // div_grad = -x/(y*y)*dz = -z/y*dz
  // div_grad_y = out_grad * z*dz/(y*y)
  // div_grad_z = out_grad * -dz/y
  // div_grad_dz = out_grad * -z/y
 public:
  // No op-level attributes to parse for this op.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Saves only the tensors the Apply() below will actually need:
  // y is always needed; z and dz are saved conditionally.
  Maybe<void> Capture(DivGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dz, z, y
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();
    ctx->z_requires_grad = inputs.at(1)->requires_grad();
    ctx->y_requires_grad = inputs.at(2)->requires_grad();

    // y participates in every branch of Apply(), so save it unconditionally.
    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(2));
    // z is needed for the dz-gradient and the y-gradient.
    if (ctx->y_requires_grad || ctx->grad_requires_grad) {
      ctx->z_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    // dz is needed for the z-gradient and the y-gradient.
    if (ctx->y_requires_grad || ctx->z_requires_grad) {
      ctx->grad_index = ctx->SaveTensorForBackward(inputs.at(0));
    }
    return Maybe<void>::Ok();
  }

  // out_grads.at(0) is the gradient flowing into the first-order gradient;
  // in_grads is filled in input order (dz, z, y).
  Maybe<void> Apply(const DivGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& y = ctx->SavedTensors().at(ctx->y_index);

    // grad_for_dz = out_grad * (-z / y)
    if (ctx->grad_requires_grad) {
      const auto& z = ctx->SavedTensors().at(ctx->z_index);
      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
                                 .then(functional::Negative)
                                 .then(std::bind(functional::Div, std::placeholders::_1, y))
                                 .call(out_grads.at(0), z));
    }
    // grad_for_z = out_grad * (-dz / y)
    if (ctx->z_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
      in_grads->at(1) = JUST(functional::sequence_function(functional::Mul)
                                 .then(functional::Negative)
                                 .then(std::bind(functional::Div, std::placeholders::_1, y))
                                 .call(out_grads.at(0), grad));
    }
    // grad_for_y = reduce_like(z * dz, y) * out_grad / y^2
    // BroadcastReduceSumLike collapses the broadcast product back to y's shape.
    if (ctx->y_requires_grad) {
      const auto& z = ctx->SavedTensors().at(ctx->z_index);
      const auto& grad = ctx->SavedTensors().at(ctx->grad_index);
      in_grads->at(2) =
          JUST(functional::sequence_function(functional::Mul)
                   .then(std::bind(functional::BroadcastReduceSumLike, std::placeholders::_1, y))
                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
                   .then(std::bind(functional::Div, std::placeholders::_1,
                                   JUST(functional::Square(y))))
                   .call(z, grad));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div_grad", DivGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/kl_div_loss.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of "kl_div_loss_grad".
struct KLDivLossGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;    // incoming grad (dy) needs a gradient
  bool input_requires_grad = false;   // forward `input` needs a gradient
  bool target_requires_grad = false;  // forward `target` needs a gradient
  bool log_target = false;            // mirrors the op's "log_target" attribute
  // Indices into SavedTensors(); assigned in Capture() when tensors are saved.
  size_t input_index = 0;
  size_t target_index = 0;
};
// Double-backward for kl_div_loss_grad. Forward op inputs: (grad, input, target).
// The first-order KL-div backward is linear in `grad`, so the double-backward
// w.r.t. `grad` is just KLDivLossGrad applied to the incoming gradient, and the
// gradients w.r.t. input/target are zero (see Apply below).
class KLDivLossGradGrad : public OpExprGradFunction<KLDivLossGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(KLDivLossGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const KLDivLossGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // attribute defaults parsed from the op proto in Init()
};

// Caches the op's default attributes so Capture() can resolve "log_target".
Maybe<void> KLDivLossGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

// Records requires_grad flags; input/target are only saved when the grad
// branch of Apply() will actually need them.
Maybe<void> KLDivLossGradGrad::Capture(KLDivLossGradGradCaptureState* ctx,
                                       const TensorTuple& inputs, const TensorTuple& outputs,
                                       const AttrMap& attrs) const {
  // grad, input, target
  CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->target_requires_grad = inputs[2]->requires_grad();

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->log_target = JUST(composed_attrs.GetAttr<bool>("log_target"));
  if (ctx->grad_requires_grad) {
    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);   // input
    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);  // target
  }
  return Maybe<void>::Ok();
}

// in_grads layout matches the forward inputs: (grad, input, target).
Maybe<void> KLDivLossGradGrad::Apply(const KLDivLossGradGradCaptureState* ctx,
                                     const TensorTuple& out_grads, TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(3);
  if (ctx->grad_requires_grad) {
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
    // First-order backward is linear in grad, so re-apply it to out_grads[0].
    (*in_grads)[0] = JUST(functional::KLDivLossGrad(out_grads[0], input, target, ctx->log_target));
  }
  // The first-order grad w.r.t. input/target does not depend on input/target
  // in a way that propagates here, so these second-order grads are zero.
  if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
  if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::ZerosLike(out_grads[0])); }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("kl_div_loss_grad", KLDivLossGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/log_softmax.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of "log_softmax_grad".
// Forward op inputs are (y, dy) where y is the log-softmax output.
struct LogSoftmaxGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;   // y (log-softmax output) needs a gradient
  bool dy_requires_grad = false;  // incoming gradient dy needs a gradient
};
// Double-backward for log_softmax_grad.
// Convention in this file: y is the log-softmax OUTPUT (so softmax = exp(y)),
// dy is the incoming first-order gradient, out_grads[0] is the gradient
// flowing into the first-order gradient.
class LogSoftmaxGradGrad : public OpExprGradFunction<LogSoftmaxGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const LogSoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

// Nothing to parse: log_softmax_grad carries no attributes this needs.
Maybe<void> LogSoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }

Maybe<void> LogSoftmaxGradGrad::Capture(LogSoftmaxGradGradCaptureState* ctx,
                                        const TensorTuple& inputs, const TensorTuple& outputs,
                                        const AttrMap& attrs) const {
  // y, dy
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->y_requires_grad = inputs[0]->requires_grad();
  ctx->dy_requires_grad = inputs[1]->requires_grad();
  // y is always needed in Apply(); dy only for the y-gradient branch.
  ctx->SaveTensorForBackward(inputs[0]);
  if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]);
  return Maybe<void>::Ok();
}

Maybe<void> LogSoftmaxGradGrad::Apply(const LogSoftmaxGradGradCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
  in_grads->resize(2);
  const auto& y = ctx->SavedTensors()[0];
  // Softmax reduces over the last dimension.
  const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};

  if (ctx->y_requires_grad) {
    const auto& dy = ctx->SavedTensors()[1];
    // grad_y = -sum(dy, last_dim, keepdim) * out_grad * exp(y)
    in_grads->at(0) =
        JUST(functional::sequence_function(functional::ReduceSum)
                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::Mul, std::placeholders::_1,
                                 JUST(functional::Exp(y))))
                 .then(functional::Negative)
                 .call(dy, reduce_axis, true));
  }
  if (ctx->dy_requires_grad) {
    // grad_dy = out_grad - sum(exp(y) * out_grad, last_dim, keepdim)
    in_grads->at(1) =
        JUST(functional::sequence_function(functional::Exp)
                 .then(std::bind(functional::Mul, std::placeholders::_1, out_grads[0]))
                 .then(std::bind(functional::ReduceSum, std::placeholders::_1, reduce_axis,
                                 /*keepdim=*/true))
                 .then(std::bind(functional::Sub, out_grads[0], std::placeholders::_1,
                                 /*alpha=*/1, /*inplace=*/false))
                 .call(y));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("log_softmax_grad", LogSoftmaxGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/math_unary_op.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Shared captured state for the double-backward of elementwise unary math ops.
// Forward op inputs are (input, grad).
struct UnaryMathGradGradState : public AutoGradCaptureState {
  bool input_requires_grad = false;  // forward input x needs a gradient
  bool grad_requires_grad = false;   // incoming first-order gradient needs a gradient
};
// Signature shared by the functional backward helpers used below:
// (input_or_y, grad) -> gradient tensor. C++11 alias syntax preferred
// over `typedef` for function-pointer types (clearer, clang-tidy
// modernize-use-using).
using UnaryBwFunc = Maybe<one::Tensor> (*)(const std::shared_ptr<one::Tensor>&,
                                           const std::shared_ptr<one::Tensor>&);
// Generic double-backward for a unary elementwise op.
// BwFunc   : first-order backward  (x, v) -> backward of op applied to v
// BwBwFunc : second-order helper   (x, grad) -> d(BwFunc)/dx term
template<UnaryBwFunc BwFunc, UnaryBwFunc BwBwFunc>
class UnaryMathGradGrad : public OpExprGradFunction<UnaryMathGradGradState> {
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->input_requires_grad = inputs[0]->requires_grad();
    ctx->grad_requires_grad = inputs[1]->requires_grad();
    // input is needed in both branches of Apply(); grad only for the
    // input-gradient branch.
    ctx->SaveTensorForBackward(inputs[0]);
    if (ctx->input_requires_grad) { ctx->SaveTensorForBackward(inputs[1]); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& input = ctx->SavedTensors()[0];
    // grad_for_input = out_grad * BwBwFunc(input, grad)
    if (ctx->input_requires_grad) {
      const auto& grad = ctx->SavedTensors()[1];
      (*in_grads)[0] = JUST(functional::Mul(out_grads[0], JUST(BwBwFunc(input, grad))));
    }
    // grad_for_grad = BwFunc(input, out_grad) — the first-order backward is
    // linear in grad, so re-applying it to out_grad suffices.
    if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }
    return Maybe<void>::Ok();
  }
};
// Double-backward for unary ops whose second derivative is (almost everywhere)
// zero — e.g. abs. The input-gradient is therefore a zeros tensor; only the
// grad-gradient is re-derived via the first-order backward BwFunc.
template<UnaryBwFunc BwFunc>
class UnaryMathGradGradWithZeroDDX : public OpExprGradFunction<UnaryMathGradGradState> {
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UnaryMathGradGradState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->input_requires_grad = inputs[0]->requires_grad();
    ctx->grad_requires_grad = inputs[1]->requires_grad();
    // Only the input is needed: the zero branch shapes off it, and BwFunc
    // takes it directly.
    ctx->SaveTensorForBackward(inputs[0]);
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UnaryMathGradGradState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    const auto& input = ctx->SavedTensors()[0];
    // Second derivative is zero, so the input-gradient is a zeros tensor.
    if (ctx->input_requires_grad) { (*in_grads)[0] = JUST(functional::ZerosLike(input)); }
    // First-order backward is linear in grad: re-apply it to out_grad.
    if (ctx->grad_requires_grad) { (*in_grads)[1] = JUST(BwFunc(input, out_grads[0])); }
    return Maybe<void>::Ok();
  }
};
// TODO: Lgamma, first order backward unimplemented
// Ops whose first-order backward is a function of (dy, x): registered with
// both a Grad and a GradGrad functional helper.
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ          \
  OF_PP_MAKE_TUPLE_SEQ("sin_grad", Sin)                         \
  OF_PP_MAKE_TUPLE_SEQ("cos_grad", Cos)                         \
  OF_PP_MAKE_TUPLE_SEQ("tan_grad", Tan)                         \
  OF_PP_MAKE_TUPLE_SEQ("sinh_grad", Sinh)                       \
  OF_PP_MAKE_TUPLE_SEQ("cosh_grad", Cosh)                       \
  OF_PP_MAKE_TUPLE_SEQ("tanh_grad", Tanh)                       \
  OF_PP_MAKE_TUPLE_SEQ("asin_grad", Asin)                       \
  OF_PP_MAKE_TUPLE_SEQ("acos_grad", Acos)                       \
  OF_PP_MAKE_TUPLE_SEQ("atan_grad", Atan)                       \
  OF_PP_MAKE_TUPLE_SEQ("asinh_grad", Asinh)                     \
  OF_PP_MAKE_TUPLE_SEQ("acosh_grad", Acosh)                     \
  OF_PP_MAKE_TUPLE_SEQ("atanh_grad", Atanh)                     \
  OF_PP_MAKE_TUPLE_SEQ("erf_grad", Erf)                         \
  OF_PP_MAKE_TUPLE_SEQ("erfc_grad", Erfc)                       \
  OF_PP_MAKE_TUPLE_SEQ("exp_grad", Exp)                         \
  OF_PP_MAKE_TUPLE_SEQ("expm1_grad", Expm1)                     \
  OF_PP_MAKE_TUPLE_SEQ("log_grad", Log)                         \
  OF_PP_MAKE_TUPLE_SEQ("log_sigmoid_grad", LogSigmoid)          \
  OF_PP_MAKE_TUPLE_SEQ("log2_grad", Log2)                       \
  OF_PP_MAKE_TUPLE_SEQ("log1p_grad", Log1p)                     \
  OF_PP_MAKE_TUPLE_SEQ("reciprocal_grad", Reciprocal)           \
  OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan_grad", ReciprocalNoNan) \
  OF_PP_MAKE_TUPLE_SEQ("rsqrt_grad", Rsqrt)                     \
  OF_PP_MAKE_TUPLE_SEQ("sqrt_grad", Sqrt)                       \
  OF_PP_MAKE_TUPLE_SEQ("square_grad", Square)

// Ops whose backward is expressed in terms of (dy, y) rather than (dy, x).
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("sigmoid_grad", Sigmoid)

// Ops with an (a.e.) zero second derivative — use the ZeroDDX variant.
#define MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("abs_grad", Abs)

// Stamps out one UnaryMathGradGrad subclass per op and registers it.
#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS(op_type_name, op_cls)           \
  class op_cls##GradGradCls final                                                            \
      : public UnaryMathGradGrad<functional::op_cls##Grad, functional::op_cls##GradGrad> {}; \
  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);

OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,
                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_X_FUNC_SEQ);
OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_CLASS,
                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_DY_Y_FUNC_SEQ);

// Same, for the zero-second-derivative family.
#define INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS(op_type_name, op_cls) \
  class op_cls##GradGradCls final                                                           \
      : public UnaryMathGradGradWithZeroDDX<functional::op_cls##Grad> {};                   \
  REGISTER_OP_EXPR_GRAD_FUNCTION(op_type_name, op_cls##GradGradCls);

OF_PP_FOR_EACH_TUPLE(INSTANTIAT_AND_REGISTER_UNARY_MATHOP_GRAD_GRAD_ZERO_DDX_CLASS,
                     MATH_UNARY_ELEMENTWISE_GRAD_GRAD_ZERO_DDX_FUNC_SEQ);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/matmul.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of "broadcast_matmul_grad_b".
struct BroadcastMatmulGradBGradCaptureState : public AutoGradCaptureState {
  bool a_requires_grad = false;  // forward input a needs a gradient
  bool b_requires_grad = false;  // forward input b needs a gradient
  // Indices into SavedTensors(); assigned in Capture() when tensors are saved.
  size_t a_index = 0;
  size_t b_index = 1;
  double alpha = 1.0;  // scaling factor, mirrors the op's "alpha" attribute
};
// Double-backward for broadcast_matmul_grad_b. Forward op inputs: (a, b);
// its single output is the first-order gradient w.r.t. the matmul's B operand.
class BroadcastMatmulGradBGrad : public OpExprGradFunction<BroadcastMatmulGradBGradCaptureState> {
 public:
  // Caches default attributes (for "alpha") from the op proto.
  // NOTE: the original had a stray ';' after this inline body — removed.
  Maybe<void> Init(const OpExpr& op) override {
    const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  // Saves each operand only if the *other* operand's gradient will need it.
  Maybe<void> Capture(BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->a_requires_grad = inputs.at(0)->requires_grad();
    ctx->b_requires_grad = inputs.at(1)->requires_grad();
    if (ctx->a_requires_grad) { ctx->b_index = ctx->SaveTensorForBackward(inputs.at(1)); }
    if (ctx->b_requires_grad) { ctx->a_index = ctx->SaveTensorForBackward(inputs.at(0)); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const BroadcastMatmulGradBGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    // for matmul: input_a[dims..., m, k] * input_b[k, n] -> [dims..., m, n]
    // if forward: BroadcastMatmulGradB(input_a, JUST(VectorAt(out_grads, 0)), ctx->alpha))
    // then: a.shape = [dims..., m, k], b.shape = [dims..., m, n], grad.shape = [k, n]
    // if forward: BroadcastMatmulGradB(JUST(VectorAt(out_grads, 0)), input_a, ctx->alpha))
    // then: a.shape = [dims..., m, n], b.shape = [dims..., m, k], grad.shape = [n, k]
    if (ctx->a_requires_grad) {
      const auto& b = ctx->SavedTensors()[ctx->b_index];
      // grad_a = alpha * b @ out_grad^T  (transpose_b = true)
      in_grads->at(0) = JUST(functional::MatMul(b, out_grads.at(0), false, true, ctx->alpha));
    }
    if (ctx->b_requires_grad) {
      const auto& a = ctx->SavedTensors()[ctx->a_index];
      // grad_b = alpha * a @ out_grad  (no transpose)
      in_grads->at(1) = JUST(functional::MatMul(a, out_grads.at(0), false, false, ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;  // attribute defaults parsed in Init()
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul_grad_b", BroadcastMatmulGradBGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/max_pool.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of max_pool_{1,2,3}d_grad.
struct MaxPoolGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;   // incoming gradient dy needs a gradient
  bool input_requires_grad = false;  // forward input x needs a gradient
};
// Double-backward for N-dimensional max pooling. The forward grad op takes
// (dy, x, indice); max-pool backward is a pure scatter by `indice`, so the
// double-backward is the matching gather (and the x-gradient is zero).
template<int ndims>
class MaxPoolNdGradGrad : public OpExprGradFunction<MaxPoolGradGradCaptureState> {
 public:
  // No attributes are needed for the second-order pass.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(MaxPoolGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // dy, x, indice
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs.at(0)->requires_grad();
    ctx->input_requires_grad = inputs.at(1)->requires_grad();
    // The pooling indices are only consumed by the dy-gradient branch.
    if (ctx->grad_requires_grad) { ctx->SaveTensorForBackward(inputs.at(2)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const MaxPoolGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(3);
    if (ctx->grad_requires_grad) {
      // Gather the incoming gradient back through the saved argmax indices.
      const auto& saved_indices = JUST(VectorAt(ctx->SavedTensors(), 0));
      in_grads->at(0) =
          JUST(functional::MaxPoolNdGradGrad(out_grads.at(0), saved_indices, ndims));
    }
    if (ctx->input_requires_grad) {
      // The scatter is piecewise-linear in dy only; d/dx is zero.
      in_grads->at(1) = JUST(functional::ZerosLike(out_grads.at(0)));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_1d_grad", MaxPoolNdGradGrad<1>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_2d_grad", MaxPoolNdGradGrad<2>);
REGISTER_OP_EXPR_GRAD_FUNCTION("max_pool_3d_grad", MaxPoolNdGradGrad<3>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool1d_grad", MaxPoolNdGradGrad<1>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool2d_grad", MaxPoolNdGradGrad<2>);
// REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_max_pool3d_grad", MaxPoolNdGradGrad<3>);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/nll_loss.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
namespace
oneflow
{
namespace
one
{
// Captured state for the double-backward of "nll_grad".
struct NLLCaptureState : public AutoGradCaptureState {
  bool input_requires_grad = false;  // forward `input` needs a gradient
  bool grad_requires_grad = false;   // incoming gradient dy needs a gradient
  bool has_weight = false;           // true when the op received the optional weight tensor
  int64_t ignore_index = -100;       // target value to skip; overwritten from the op attr
};
// Double-backward for nll_grad. Forward op inputs: (dy, input, target[, weight]).
// NLL backward is linear in dy, so the dy-gradient re-applies NLLLoss to the
// incoming gradient; the input-gradient is zero.
class NLLLossGradGrad : public OpExprGradFunction<NLLCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(NLLCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;  // attribute defaults parsed from the op proto in Init()
};

// Caches the op's default attributes so Capture() can resolve "ignore_index".
Maybe<void> NLLLossGradGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> NLLLossGradGrad::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  // dy, input, target[, weight]
  CHECK_OR_RETURN(inputs.size() >= 3 && inputs.size() <= 4);  // NOLINT(maybe-need-error-msg)
  ctx->grad_requires_grad = inputs[0]->requires_grad();
  ctx->input_requires_grad = inputs[1]->requires_grad();
  ctx->has_weight = inputs.size() == 4;

  // target (and weight, and ignore_index) are only consumed by the
  // dy-gradient branch of Apply(), so capture them lazily.
  if (ctx->grad_requires_grad) {
    ctx->SaveTensorForBackward(inputs[2]);
    if (ctx->has_weight) { ctx->SaveTensorForBackward(inputs[3]); }  // weight
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>("ignore_index"));
  }
  return Maybe<void>::Ok();
}

Maybe<void> NLLLossGradGrad::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
                                   TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  // Slot count mirrors the forward input count (3 or 4, depending on weight).
  in_grads->resize(3 + ctx->has_weight);
  if (ctx->grad_requires_grad) {
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), 0));
    // Re-apply NLL (reduction "none") to out_grads[0]: the first-order
    // backward is linear in dy.
    if (ctx->has_weight) {
      auto weight = JUST(VectorAt(ctx->SavedTensors(), 1));
      (*in_grads)[0] =
          JUST(functional::NLLLoss(out_grads[0], target, weight, ctx->ignore_index, "none"));
    } else {
      (*in_grads)[0] =
          JUST(functional::NLLLoss(out_grads[0], target, NullOpt, ctx->ignore_index, "none"));
    }
  }
  if (ctx->input_requires_grad) { (*in_grads)[1] = JUST(functional::ZerosLike(out_grads[0])); }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("nll_grad", NLLLossGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/pow.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <functional>
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Capture state for the double backward of "pow_x_grad" (inputs: x, y, dz).
struct PowXGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;   // base tensor needs grad
  bool y_requires_grad = false;   // exponent tensor needs grad
  bool dz_requires_grad = false;  // incoming first-order gradient needs grad
  // Indices into SavedTensors(); defaults match the unconditional save order
  // used in Capture, but the authoritative values are assigned there.
  size_t x_index = 0;
  size_t y_index = 1;
  size_t dz_index = 2;
};
// Double backward of "pow_x_grad": z = x^y, dx = y * x^(y-1) * dz.
// Computes gradients w.r.t. x, y, and dz given out_grads = d(..)/d(dx).
class PowXGradGrad : public OpExprGradFunction<PowXGradGradCaptureState> {
 public:
  // No op attributes needed; Init is a no-op.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(PowXGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, y, dz
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->dz_requires_grad = inputs.at(2)->requires_grad();

    // x and y are needed by every branch of Apply(); dz only when x or y
    // requires grad, so it is saved conditionally.
    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2));
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const PowXGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    const auto& y = ctx->SavedTensors().at(ctx->y_index);

    // dx = y * x^(y-1) * dz
    // grad_for_x = out_grads * dz * y * [x^(y-1)]'
    // grad_for_y = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]
    // grad_for_dz = out_grads * y * x^(y-1)
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      const auto& dz = ctx->SavedTensors().at(ctx->dz_index);
      const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));
      if (ctx->x_requires_grad) {
        // PowXGrad(x, y-1, og) = (y-1) * x^(y-2) * og; then * y * dz gives
        // y * (y-1) * x^(y-2) * dz * og, the second derivative w.r.t. x.
        in_grads->at(0) = JUST(functional::sequence_function(functional::PowXGrad)
                                   .then(std::bind(functional::Mul, std::placeholders::_1, y))
                                   .then(std::bind(functional::Mul, std::placeholders::_1, dz))
                                   .call(x, y_sub_one, out_grads.at(0)));
      }
      if (ctx->y_requires_grad) {
        // Builds x^(y-1) * (1 + y*ln(x)) * dz * og step by step:
        // ln(x) -> y*ln(x) -> 1 + y*ln(x) -> * x^(y-1) -> * dz -> * og.
        in_grads->at(1) =
            JUST(functional::sequence_function(functional::Log)
                     .then(std::bind(functional::Mul, std::placeholders::_1, y))
                     .then([](const std::shared_ptr<Tensor>& input) {
                       return functional::ScalarAdd(1, input, /*alpha=*/1);
                     })
                     .then(std::bind(functional::Mul, std::placeholders::_1,
                                     JUST(functional::Pow(x, y_sub_one))))
                     .then(std::bind(functional::Mul, std::placeholders::_1, dz))
                     .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
                     .call(x));
      }
    }
    if (ctx->dz_requires_grad) {
      // dx is linear in dz, so grad_for_dz reuses the first-order formula.
      in_grads->at(2) = JUST(functional::PowXGrad(x, y, out_grads.at(0)));
    }
    return Maybe<void>::Ok();
  }
};
// Capture state for the double backward of "pow_y_grad" (inputs: x, y, dz;
// output: dy).
struct PowYGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;   // base tensor needs grad
  bool y_requires_grad = false;   // exponent tensor needs grad
  bool dz_requires_grad = false;  // incoming first-order gradient needs grad
  // Indices into SavedTensors(); assigned by Capture for the tensors that are
  // actually saved (saving is partly conditional on the flags above).
  size_t x_index = 0;
  size_t y_index = 1;
  size_t dz_index = 2;
  size_t dy_index = 3;
};
// Double backward of "pow_y_grad": z = x^y, dy = x^y * ln(x) * dz.
// Computes gradients w.r.t. x, y, and dz given out_grads = d(..)/d(dy).
class PowYGradGrad : public OpExprGradFunction<PowYGradGradCaptureState> {
 public:
  // dy = x^y*ln(x)*dz = z*ln(x)*dz
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(PowYGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // x, y, dz
    CHECK_EQ_OR_RETURN(inputs.size(), 3);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->dz_requires_grad = inputs.at(2)->requires_grad();

    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    // BUGFIX: y must be saved unconditionally. The previous guard
    // (x_requires_grad || y_requires_grad) left y unsaved when only dz
    // requires grad, but Apply()'s dz branch still reads
    // SavedTensors().at(ctx->y_index) — an out-of-range access at runtime.
    // Saving unconditionally mirrors PowXGradGrad::Capture.
    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    // dz is only consumed by the x branch; dy (the forward output) only by
    // the y branch.
    if (ctx->x_requires_grad) { ctx->dz_index = ctx->SaveTensorForBackward(inputs.at(2)); }
    if (ctx->y_requires_grad) { ctx->dy_index = ctx->SaveTensorForBackward(outputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const PowYGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(3);
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    // dy = x^y * ln(x) * dz = z * ln(x) * dz
    // grad_for_x = out_grads * dz * [x^(y-1) * (1 + y * ln(x))]
    // grad_for_y = out_grads * dy' = out_grads * dy * ln(x)
    // grad_for_dz = out_grads * x^y * ln(x)
    if (ctx->x_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      const auto& dz = ctx->SavedTensors().at(ctx->dz_index);
      const auto& y_sub_one = JUST(functional::ScalarSub(y, 1, /*alpha=*/1, /*inplace=*/false));
      // ln(x) -> y*ln(x) -> 1 + y*ln(x) -> * x^(y-1) -> * dz -> * out_grads.
      in_grads->at(0) =
          JUST(functional::sequence_function(functional::Log)
                   .then(std::bind(functional::Mul, std::placeholders::_1, y))
                   .then([](const std::shared_ptr<Tensor>& input) {
                     return functional::ScalarAdd(1, input, /*alpha=*/1);
                   })
                   .then(std::bind(functional::Mul, std::placeholders::_1,
                                   JUST(functional::Pow(x, y_sub_one))))
                   .then(std::bind(functional::Mul, std::placeholders::_1, dz))
                   .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
                   .call(x));
    }
    if (ctx->y_requires_grad) {
      const auto& dy = ctx->SavedTensors().at(ctx->dy_index);
      // d(dy)/dy = dy * ln(x), scaled by out_grads.
      in_grads->at(1) = JUST(
          functional::sequence_function(functional::Log)
              .then(std::bind(functional::Mul, std::placeholders::_1, dy))
              .then(std::bind(functional::Mul, std::placeholders::_1, out_grads.at(0)))
              .call(x));
    }
    if (ctx->dz_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      // dy is linear in dz, so grad_for_dz reuses the first-order formula.
      in_grads->at(2) = JUST(functional::PowYGrad(x, y, out_grads.at(0)));
    }
    return Maybe<void>::Ok();
  }
};
// Register the double-backward functions for pow's two first-order backward ops.
REGISTER_OP_EXPR_GRAD_FUNCTION("pow_x_grad", PowXGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("pow_y_grad", PowYGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/scalar_pow.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Shared capture state for the double backward of "scalar_pow_grad" and
// "scalar_reverse_pow_grad" (inputs: x, grad).
struct ScalarPowGradGradCaptureState : public AutoGradCaptureState {
  bool x_requires_grad = false;     // tensor operand needs grad
  bool grad_requires_grad = false;  // incoming first-order gradient needs grad
  Scalar operand;                   // the scalar exponent/base from op attrs
};
// Double backward of "scalar_pow_grad": z = x^a (a scalar), dx = a * x^(a-1) * dz.
class ScalarPowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    // BUGFIX: check for null BEFORE dereferencing. The original called
    // fw_op_expr->proto() first, dereferencing a possibly-null pointer when
    // the dynamic_cast fails (sibling grad functions check first).
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    // Nothing to do if no input needs a gradient.
    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }

    // The scalar exponent is stored either as a double or an int64 attribute.
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
    if (has_float_operand) {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
    } else {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
    }
    // SavedTensors layout: [x, grad?] — grad only needed for the x branch.
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);
    in_grads->resize(2);
    // z = x^a, dx = a * x^(a-1) * dz
    // grad_for_x = out_grad * a * dz * [x^(a-1)]'
    // grad_for_dz = out_grad * [x^a]'
    if (ctx->x_requires_grad) {
      const auto& grad = ctx->SavedTensors().at(1);
      const auto operand_sub_one = ctx->operand - Scalar(1);
      // Mul(grad, out_grad) -> ScalarPowGrad(x, ., a-1) = (a-1)*x^(a-2)*(.)
      // -> * a, yielding a*(a-1)*x^(a-2)*grad*out_grad.
      in_grads->at(0) =
          JUST(functional::sequence_function(functional::Mul)
                   .then(std::bind(functional::ScalarPowGrad, x, std::placeholders::_1,
                                   operand_sub_one))
                   .then([&ctx](const std::shared_ptr<Tensor>& input) {
                     return functional::ScalarMul(ctx->operand, input);
                   })
                   .call(grad, out_grads.at(0)));
    }
    if (ctx->grad_requires_grad) {
      // dx is linear in dz, so reuse the first-order formula.
      in_grads->at(1) = JUST(functional::ScalarPowGrad(x, out_grads.at(0), ctx->operand));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
// Double backward of "scalar_reverse_pow_grad": z = a^x (a scalar),
// dx = a^x * ln(a) * dz.
class ScalarReversePowGradGrad : public OpExprGradFunction<ScalarPowGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    // BUGFIX: check for null BEFORE dereferencing. The original called
    // fw_op_expr->proto() first, dereferencing a possibly-null pointer when
    // the dynamic_cast fails (sibling grad functions check first).
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ScalarPowGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->grad_requires_grad = inputs.at(1)->requires_grad();
    // Nothing to do if no input needs a gradient.
    if (!(ctx->x_requires_grad || ctx->grad_requires_grad)) { return Maybe<void>::Ok(); }

    // The scalar base is stored either as a double or an int64 attribute.
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    bool has_float_operand = JUST(composed_attrs.GetAttr<bool>("has_float_operand"));
    if (has_float_operand) {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<double>("float_operand")));
    } else {
      ctx->operand = Scalar(JUST(composed_attrs.GetAttr<int64_t>("int_operand")));
    }
    // SavedTensors layout: [x, dx?] — the forward output dx = a^x*ln(a)*dz is
    // only needed for the x branch.
    ctx->SaveTensorForBackward(inputs.at(0));
    if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ScalarPowGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(0);
    in_grads->resize(2);
    // z = a^x, dx = a^x * ln(a) * dz
    // grad_for_x = out_grad * dz * a^x * ln(a) * ln(a)
    // grad_for_dz = out_grad * [a^x]'
    if (ctx->x_requires_grad) {
      const auto& dx = ctx->SavedTensors().at(1);
      // NOTE: ln(a) is undefined for a <= 0; std::log then yields NaN/-inf,
      // matching the math of the formula itself.
      const auto log_operand = std::log(ctx->operand.As<double>());
      // d(dx)/dx = dx * ln(a); scale by out_grad.
      in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
                                 .then([&log_operand](const std::shared_ptr<Tensor>& input) {
                                   return functional::ScalarMul(log_operand, input);
                                 })
                                 .call(dx, out_grads.at(0)));
    }
    if (ctx->grad_requires_grad) {
      // dx is linear in dz, so reuse the first-order formula.
      in_grads->at(1) = JUST(functional::ScalarReversePowGrad(x, out_grads.at(0), ctx->operand));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
// Register the double-backward functions for the scalar pow backward ops.
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_pow_grad", ScalarPowGradGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_reverse_pow_grad", ScalarReversePowGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/slice.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Capture state for the double backward of "slice_grad": the slice window,
// one entry per sliced dimension.
struct SliceGradGradCaptureState : public AutoGradCaptureState {
  std::vector<int64_t> start;  // inclusive start index per dimension
  std::vector<int64_t> stop;   // exclusive stop index per dimension
  std::vector<int64_t> step;   // stride per dimension
};
// Double backward of "slice_grad". Since slice_grad scatters its input into a
// zero tensor, its gradient is simply re-slicing with the same window.
class SliceGradGrad : public OpExprGradFunction<SliceGradGradCaptureState> {
 public:
  // Caches the forward op's attribute defaults for later lookup.
  Maybe<void> Init(const OpExpr& op) override {
    const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(user_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
    return Maybe<void>::Ok();
  }

  // Records the slice window (start/stop/step) used by the slice_grad op.
  Maybe<void> Capture(SliceGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    using Int64Vec = std::vector<int64_t>;
    ctx->start = JUST(composed_attrs.GetAttr<Int64Vec>("start"));
    ctx->stop = JUST(composed_attrs.GetAttr<Int64Vec>("stop"));
    ctx->step = JUST(composed_attrs.GetAttr<Int64Vec>("step"));
    return Maybe<void>::Ok();
  }

  // Gradient w.r.t. slice_grad's input = slice of the incoming gradient.
  Maybe<void> Apply(const SliceGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    auto sliced_grad = JUST(functional::Slice(out_grads.at(0), ctx->start, ctx->stop, ctx->step,
                                              /*enable_view_slice=*/false));
    in_grads->at(0) = sliced_grad;
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
// Registers SliceGradGrad as the backward of the first-order "slice_grad" op.
REGISTER_OP_EXPR_GRAD_FUNCTION("slice_grad", SliceGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/smooth_l1_loss.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Capture state for the double backward of "smooth_l1_loss_grad"
// (inputs: grad, input, target).
struct SmoothL1LossGradGradCaptureState : public AutoGradCaptureState {
  bool grad_requires_grad = false;    // incoming first-order gradient needs grad
  bool input_requires_grad = false;   // prediction tensor needs grad
  bool target_requires_grad = false;  // target tensor needs grad
  // Indices into SavedTensors(), assigned in Capture (grad is saved
  // conditionally, so its index is only meaningful when the flag above holds).
  size_t grad_index = 0;
  size_t input_index = 0;
  size_t target_index = 0;
  float beta = 0.0;  // smooth-L1 transition point, from the op's "beta" attr
};
// Double backward of "smooth_l1_loss_grad".
// First-order backward: d(loss)/d(input) = (input-target)/beta * grad for
// |input-target| < beta, else sign(input-target) * grad.
class SmoothL1LossGradGrad : public OpExprGradFunction<SmoothL1LossGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    // grad, input, target
    CHECK_EQ_OR_RETURN(inputs.size(), 3);  // NOLINT(maybe-need-error-msg)
    ctx->grad_requires_grad = inputs[0]->requires_grad();
    ctx->input_requires_grad = inputs[1]->requires_grad();
    ctx->target_requires_grad = inputs[2]->requires_grad();

    // grad is only consumed by the input/target branches of Apply(); input
    // and target are needed by every branch and are saved unconditionally.
    if (ctx->input_requires_grad || ctx->target_requires_grad) {
      ctx->grad_index = ctx->SaveTensorForBackward(inputs[0]);
    }
    ctx->input_index = ctx->SaveTensorForBackward(inputs[1]);
    ctx->target_index = ctx->SaveTensorForBackward(inputs[2]);

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->beta = JUST(composed_attrs.GetAttr<float>("beta"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SmoothL1LossGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(3);
    const auto& input = JUST(VectorAt(ctx->SavedTensors(), ctx->input_index));
    const auto& target = JUST(VectorAt(ctx->SavedTensors(), ctx->target_index));
    if (ctx->grad_requires_grad) {
      // The first-order backward is linear in grad, so its gradient w.r.t.
      // grad reuses the first-order formula on out_grads.
      (*in_grads)[0] = JUST(functional::SmoothL1LossGrad(out_grads[0], input, target, ctx->beta));
    }
    if (ctx->input_requires_grad || ctx->target_requires_grad) {
      const auto& grad = JUST(VectorAt(ctx->SavedTensors(), ctx->grad_index));
      // condition = |input - target| < beta, i.e. the quadratic region mask.
      auto condition = JUST(functional::sequence_function(functional::Sub)
                                .then(functional::Abs)
                                .then([&ctx](const std::shared_ptr<Tensor>& input) {
                                  return functional::ScalarLogicalLess(input, ctx->beta);
                                })
                                .call(input, target, /*alpha=*/1, /*inplace=*/false));
      // Second derivative is 1/beta inside the quadratic region and 0 outside:
      // out = out_grads * grad * condition / beta.
      auto out = JUST(functional::sequence_function(functional::Mul)
                          .then(std::bind(functional::Mul, std::placeholders::_1, condition))
                          .then([&ctx](const std::shared_ptr<Tensor>& input) {
                            // Guard against division by zero when beta == 0
                            // (pure-L1 case: second derivative is 0).
                            double inv_beta = ctx->beta == 0.0 ? 0.0 : 1.0 / ctx->beta;
                            return functional::ScalarMul(inv_beta, input);
                          })
                          .call(out_grads[0], grad));
      if (ctx->input_requires_grad) { (*in_grads)[1] = out; }
      // The loss depends on (input - target), so the target gradient is the
      // negation of the input gradient.
      if (ctx->target_requires_grad) { (*in_grads)[2] = JUST(functional::Negative(out)); }
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};
// Registers SmoothL1LossGradGrad as the backward of "smooth_l1_loss_grad".
REGISTER_OP_EXPR_GRAD_FUNCTION("smooth_l1_loss_grad", SmoothL1LossGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/higher_order_gradient_funcs/softmax.cpp
0 → 100644
View file @
a715222c
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/sequence_function.h"
namespace
oneflow
{
namespace
one
{
// Capture state for the double backward of "softmax_grad" (inputs: y, dy).
struct SoftmaxGradGradCaptureState : public AutoGradCaptureState {
  bool y_requires_grad = false;   // softmax output needs grad
  bool dy_requires_grad = false;  // incoming first-order gradient needs grad
};
// Double backward of "softmax_grad": given out_grads = d(..)/d(dx) where
// dx = softmax_grad(dy, y), produces gradients w.r.t. y and dy.
class SoftmaxGradGrad : public OpExprGradFunction<SoftmaxGradGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const SoftmaxGradGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};
// softmax_grad has no attributes that the double backward needs; Init is a no-op.
Maybe<void> SoftmaxGradGrad::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
Maybe<void> SoftmaxGradGrad::Capture(SoftmaxGradGradCaptureState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  // y, dy
  CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->y_requires_grad = inputs[0]->requires_grad();
  ctx->dy_requires_grad = inputs[1]->requires_grad();
  // y is consumed by both branches of Apply() and is saved unconditionally
  // at index 0; dy (index 1) is only needed when y requires grad.
  ctx->SaveTensorForBackward(inputs[0]);
  if (ctx->y_requires_grad) ctx->SaveTensorForBackward(inputs[1]);
  return Maybe<void>::Ok();
}
Maybe<void> SoftmaxGradGrad::Apply(const SoftmaxGradGradCaptureState* ctx,
                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {
  in_grads->resize(2);
  // SavedTensors layout (see Capture): [y, dy?].
  const auto& y = ctx->SavedTensors()[0];
  if (ctx->y_requires_grad) {
    const auto& dy = ctx->SavedTensors()[1];
    // Reductions run over the last axis — assumes softmax was taken over the
    // last dimension, matching softmax_grad's convention.
    const std::vector<int32_t> reduce_axis{static_cast<int32_t>(y->ndim() - 1)};
    // a = sum(y * out_grads, last_dim, keepdim) * dy
    const auto& a = JUST(functional::sequence_function(functional::Mul)
                             .then(std::bind(functional::ReduceSum, std::placeholders::_1,
                                             reduce_axis, /*keepdim=*/true))
                             .then(std::bind(functional::Mul, std::placeholders::_1, dy))
                             .call(y, out_grads[0]));
    // b = sum(y * dy, last_dim, keepdim) * out_grads
    const auto& b = JUST(functional::sequence_function(functional::Mul)
                             .then(std::bind(functional::ReduceSum, std::placeholders::_1,
                                             reduce_axis, /*keepdim=*/true))
                             .then(std::bind(functional::Mul, std::placeholders::_1,
                                             out_grads[0]))
                             .call(y, dy));
    // grad_for_y = out_grads * dy - a - b
    in_grads->at(0) = JUST(functional::sequence_function(functional::Mul)
                               .then(std::bind(functional::Sub, std::placeholders::_1, a,
                                               /*alpha=*/1, /*inplace=*/false))
                               .then(std::bind(functional::Sub, std::placeholders::_1, b,
                                               /*alpha=*/1, /*inplace=*/false))
                               .call(out_grads[0], dy));
  }
  if (ctx->dy_requires_grad) {
    // softmax_grad is linear in dy, so grad_for_dy reuses the first-order op.
    in_grads->at(1) = JUST(functional::SoftmaxGrad(out_grads[0], y));
  }
  return Maybe<void>::Ok();
}
// Registers SoftmaxGradGrad as the backward of the first-order "softmax_grad" op.
REGISTER_OP_EXPR_GRAD_FUNCTION("softmax_grad", SoftmaxGradGrad);
}
// namespace one
}
// namespace oneflow
oneflow/core/boxing/asymmetric_broadcast.cpp
View file @
a715222c
...
...
@@ -39,6 +39,8 @@ Maybe<void> RawCheckAsymmetricBroadcast(Symbol<PlacedNdSbp> in, Symbol<PlacedNdS
CHECK_OR_RETURN
(
NdSbpIsAllBroadcast
(
*
out
->
nd_sbp
()));
CHECK_OR_RETURN
(
out
->
placement
()
->
Bigger
(
*
in
->
placement
())
||
in
->
placement
()
->
Bigger
(
*
out
->
placement
()));
CHECK_OR_RETURN
(
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCPU
||
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCUDA
);
// NOLINTEND(maybe-need-error-msg)
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -76,16 +78,19 @@ Maybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
static
constexpr
auto
*
CachedGetBroadcastRoot
=
DECORATE
(
&
CalBroadcastRoot
,
ThreadLocalCached
);
Maybe
<
one
::
UserOpExpr
>
EagerNcclBroadcast
(
Symbol
<
ParallelDesc
>
parallel_desc
,
int64_t
root
)
{
return
one
::
OpBuilder
(
"eager_nccl_broadcast"
,
*
JUST
(
UniqueStr
(
"eager_nccl_broadcast"
)))
Maybe
<
one
::
UserOpExpr
>
EagerCclBroadcast
(
Symbol
<
ParallelDesc
>
parallel_desc
,
int64_t
root
,
const
Shape
&
shape
)
{
return
one
::
OpBuilder
(
"eager_ccl_broadcast"
,
*
JUST
(
UniqueStr
(
"eager_ccl_broadcast"
)))
.
Input
(
"in"
)
.
Output
(
"out"
)
.
Attr
<
std
::
string
>
(
"parallel_conf"
,
PbMessage2TxtString
(
parallel_desc
->
parallel_conf
()))
.
Attr
<
std
::
vector
<
Shape
>>
(
"shape_list"
,
{
shape
})
.
Attr
<
int64_t
>
(
"root"
,
root
)
.
Build
();
}
static
constexpr
auto
*
CachedEagerNcclBroadcast
=
DECORATE
(
&
EagerNcclBroadcast
,
ThreadLocalCached
);
static
constexpr
auto
*
CachedEagerCclBroadcast
=
DECORATE
(
&
EagerCclBroadcast
,
ThreadLocalCachedCopiable
);
}
// namespace
Maybe
<
one
::
Tensor
>
AsymmetricBroadcast
(
const
std
::
shared_ptr
<
one
::
Tensor
>&
tensor
,
...
...
@@ -105,26 +110,19 @@ Maybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tenso
if
(
out
->
placement
()
->
Bigger
(
*
in
->
placement
()))
{
const
auto
&
out_parallel_id
=
JUST
(
GetParallelId4CurrentProcessCtx
(
out_placement
));
if
(
out_parallel_id
->
has_value
())
{
const
auto
&
in_parallel_id
=
JUST
(
GetParallelId4CurrentProcessCtx
(
in_placement
));
if
(
!
in_parallel_id
->
has_value
())
{
const
std
::
string
&
device_type
=
in_placement
->
device_tag
();
local_tensor
=
JUST
(
one
::
functional
::
Empty
(
*
tensor
->
shape
(),
tensor
->
dtype
(),
JUST
(
Device
::
New
(
device_type
)),
/*pin_memory=*/
false
));
}
const
auto
&
broadcast_group
=
JUST
(
GetBroadcastGroup
(
in_placement
,
out_placement
));
Symbol
<
ParallelDesc
>
broadcast_placement_cur_rank
=
JUST
(
MapAt
(
*
broadcast_group
,
GlobalProcessCtx
::
Rank
()));
int64_t
root
=
JUST
(
CachedGetBroadcastRoot
(
in_placement
,
broadcast_placement_cur_rank
));
std
::
shared_ptr
<
one
::
UserOpExpr
>
op_expr
=
JUST
(
CachedEager
Nc
clBroadcast
(
broadcast_placement_cur_rank
,
root
));
JUST
(
CachedEager
C
clBroadcast
(
broadcast_placement_cur_rank
,
root
,
*
tensor
->
shape
()
));
local_tensor
=
JUST
(
one
::
OpInterpUtil
::
Dispatch
<
one
::
Tensor
>
(
*
op_expr
,
{
local_tensor
}));
}
}
return
one
::
functional
::
LocalTo
Consistent
(
local_tensor
,
out_placement
,
*
JUST
(
GetSbpList
(
out
->
nd_sbp
())),
*
tensor
->
shape
(),
tensor
->
dtype
());
return
one
::
functional
::
LocalTo
Global
(
local_tensor
,
out_placement
,
*
JUST
(
GetSbpList
(
out
->
nd_sbp
())),
*
tensor
->
shape
(),
tensor
->
dtype
()
,
/* sync_data */
false
,
/*copy=*/
false
);
}
COMMAND
(
RegisterBoxingFunction
(
"asymmetric-broadcast"
,
CheckAsymmetricBroadcast
,
...
...
oneflow/core/boxing/boxing_interpreter_status.h
View file @
a715222c
...
...
@@ -85,17 +85,15 @@ namespace std {
template
<
>
struct
hash
<
oneflow
::
BoxingInterpreterStatus
>
{
size_t
operator
()(
const
oneflow
::
BoxingInterpreterStatus
&
status
)
const
{
using
namespace
oneflow
;
size_t
ret
=
0
;
for
(
const
auto
&
boxing_name
:
*
status
.
sorted_boxing_names
())
{
ret
^=
std
::
hash
<
string
>
()(
boxing_name
);
}
const
auto
&
placed_nd_sbp_hash
=
std
::
hash
<
oneflow
::
PlacedNdSbp
>
();
ret
^=
placed_nd_sbp_hash
(
*
status
.
src_placed_nd_sbp
());
for
(
const
auto
&
boxing_name
:
*
status
.
sorted_boxing_names
())
{
AddHash
(
&
ret
,
boxing_name
);
}
AddHash
(
&
ret
,
*
status
.
src_placed_nd_sbp
());
for
(
const
auto
&
mid_placed_nd_sbp
:
*
status
.
mid_placed_nd_sbp
())
{
ret
^=
placed_nd_sbp_hash
(
*
mid_placed_nd_sbp
);
AddHash
(
&
ret
,
*
mid_placed_nd_sbp
);
}
ret
^=
placed_nd_sbp_hash
(
*
status
.
dst_placed_nd_sbp
());
return
hash
<
size_t
>
()(
ret
)
;
AddHash
(
&
ret
,
*
status
.
dst_placed_nd_sbp
());
return
ret
;
}
};
...
...
oneflow/core/boxing/ccl_boxing_function.cpp
View file @
a715222c
...
...
@@ -13,16 +13,55 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
namespace
oneflow
{
namespace
{
// Minimal KernelRegContext used only to ask the registry whether a kernel is
// registered for a given device type. Only device_type() returns a value;
// every other accessor aborts, since kernel-registration predicates queried
// through this context are expected to consult the device type alone.
class EagerBoxingKernelRegContext final : public user_op::KernelRegContext {
 public:
  explicit EagerBoxingKernelRegContext(DeviceType device_type) : device_type_(device_type) {}
  ~EagerBoxingKernelRegContext() = default;

  DeviceType device_type() const override { return device_type_; }
  // The queries below are meaningless here; reaching them indicates a bug.
  const ParallelContext& parallel_ctx() const override { PRINT_BUG_PROMPT_AND_ABORT(); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    PRINT_BUG_PROMPT_AND_ABORT();
  }
  const std::vector<std::pair<std::string, int32_t>>& inputs() const override {
    PRINT_BUG_PROMPT_AND_ABORT();
  }
  const std::vector<std::pair<std::string, int32_t>>& outputs() const override {
    PRINT_BUG_PROMPT_AND_ABORT();
  }
  const user_op::UserOpConfWrapper& user_op_conf() const override { PRINT_BUG_PROMPT_AND_ABORT(); }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    PRINT_BUG_PROMPT_AND_ABORT();
  }

 private:
  DeviceType device_type_;  // the only piece of state this context carries
};
// Returns whether a kernel of `op_type_name` is registered for `device_type`,
// using the stub context above (only the device type is consulted).
Maybe<bool> RawCheckCclKernelRegistered(const std::string& op_type_name, DeviceType device_type) {
  EagerBoxingKernelRegContext reg_ctx(device_type);
  return user_op::UserOpRegistryMgr::Get().IsOpKernelRegistered(op_type_name, reg_ctx);
}
// Thread-local memoized wrapper to avoid repeated registry lookups for the
// same (op_type_name, device_type) pair.
static constexpr auto* CheckCclKernelRegistered =
    DECORATE(&RawCheckCclKernelRegistered, ThreadLocalCachedCopiable);
Maybe
<
void
>
RawCheckCclP2B
(
Symbol
<
PlacedNdSbp
>
in
,
Symbol
<
PlacedNdSbp
>
out
,
const
Shape
&
logical_shape
)
{
// NOLINTBEGIN(maybe-need-error-msg)
...
...
@@ -33,8 +72,9 @@ Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN
(
NdSbpIsAllBroadcast
(
*
out
->
nd_sbp
()));
CHECK_OR_RETURN
(
in
->
placement
()
==
out
->
placement
());
CHECK_OR_RETURN
(
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCPU
||
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCUDA
);
CHECK_OR_RETURN
(
// NOLINT
JUST
(
CheckCclKernelRegistered
(
"eager_ccl_all_reduce"
,
// NOLINT
in
->
placement
()
->
device_type
())));
// NOLINT
// NOLINTEND(maybe-need-error-msg)
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -53,8 +93,9 @@ Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN
(
logical_shape
.
At
(
0
)
%
in
->
placement
()
->
parallel_num
()
==
0
);
CHECK_OR_RETURN
(
in
->
placement
()
==
out
->
placement
());
CHECK_OR_RETURN
(
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCPU
||
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCUDA
);
CHECK_OR_RETURN
(
// NOLINT
JUST
(
CheckCclKernelRegistered
(
"eager_ccl_reduce_scatter"
,
// NOLINT
in
->
placement
()
->
device_type
())));
// NOLINT
// NOLINTEND(maybe-need-error-msg)
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -74,8 +115,9 @@ Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
CHECK_OR_RETURN
(
logical_shape
.
At
(
0
)
%
in
->
placement
()
->
parallel_num
()
==
0
);
CHECK_OR_RETURN
(
in
->
placement
()
==
out
->
placement
());
CHECK_OR_RETURN
(
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCPU
||
in
->
placement
()
->
device_type
()
==
DeviceType
::
kCUDA
);
CHECK_OR_RETURN
(
// NOLINT
JUST
(
CheckCclKernelRegistered
(
"eager_ccl_all_gather"
,
// NOLINT
in
->
placement
()
->
device_type
())));
// NOLINT
// NOLINTEND(maybe-need-error-msg)
return
Maybe
<
void
>::
Ok
();
}
...
...
@@ -122,7 +164,7 @@ Maybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<<
Error
::
RuntimeError
()
<<
"The placement of input tensor ("
<<
*
JUST
(
PlacementToString
(
tensor_placement
))
<<
") must match the input placement ("
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
return
JUST
(
one
::
functional
::
Consistent
AllReduce
(
tensor
));
return
JUST
(
one
::
functional
::
Global
AllReduce
(
tensor
));
}
Maybe
<
one
::
Tensor
>
CclP2S
(
const
std
::
shared_ptr
<
one
::
Tensor
>&
tensor
,
Symbol
<
PlacedNdSbp
>
in
,
...
...
@@ -137,7 +179,7 @@ Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<<
*
JUST
(
PlacementToString
(
tensor_placement
))
<<
") must match the input placement ("
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
return
JUST
(
one
::
functional
::
Consistent
ReduceScatter
(
tensor
,
"sum"
));
return
JUST
(
one
::
functional
::
Global
ReduceScatter
(
tensor
,
"sum"
));
}
Maybe
<
one
::
Tensor
>
CclS2B
(
const
std
::
shared_ptr
<
one
::
Tensor
>&
tensor
,
Symbol
<
PlacedNdSbp
>
in
,
...
...
@@ -151,7 +193,7 @@ Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<<
Error
::
RuntimeError
()
<<
"The placement of input tensor ("
<<
*
JUST
(
PlacementToString
(
tensor_placement
))
<<
") must match the input placement ("
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
return
JUST
(
one
::
functional
::
Consistent
AllGather
(
tensor
));
return
JUST
(
one
::
functional
::
Global
AllGather
(
tensor
));
}
Maybe
<
one
::
Tensor
>
CclS2S
(
const
std
::
shared_ptr
<
one
::
Tensor
>&
tensor
,
Symbol
<
PlacedNdSbp
>
in
,
...
...
@@ -165,7 +207,7 @@ Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<Pla
<<
Error
::
RuntimeError
()
<<
"The placement of input tensor ("
<<
*
JUST
(
PlacementToString
(
tensor_placement
))
<<
") must match the input placement ("
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
return
JUST
(
one
::
functional
::
Consistent
S2S
(
tensor
,
*
JUST
(
GetSbpList
(
out
->
nd_sbp
()))));
return
JUST
(
one
::
functional
::
Global
S2S
(
tensor
,
*
JUST
(
GetSbpList
(
out
->
nd_sbp
()))));
}
COMMAND
(
RegisterBoxingFunction
(
"ccl-p-to-b"
,
CheckCclP2B
,
&
CclP2B
));
...
...
oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp
View file @
a715222c
...
...
@@ -63,17 +63,11 @@ Maybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor
<<
Error
::
RuntimeError
()
<<
"The placement of input tensor ("
<<
*
JUST
(
PlacementToString
(
tensor_placement
))
<<
") must match the input placement ("
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
std
::
shared_ptr
<
one
::
Tensor
>
local_tensor
=
JUST
(
tensor
->
cur_rank_phy_tensor
());
const
auto
&
out_parallel_id
=
JUST
(
GetParallelId4CurrentProcessCtx
(
out
->
placement
()));
if
(
!
out_parallel_id
->
has_value
())
{
const
std
::
string
&
device_type
=
tensor_placement
->
device_tag
();
local_tensor
=
JUST
(
one
::
functional
::
Empty
(
*
JUST
(
GetPhysicalShape
(
*
tensor
->
shape
(),
*
tensor_nd_sbp
,
*
tensor_placement
,
0
)),
tensor
->
dtype
(),
JUST
(
Device
::
New
(
device_type
)),
/*pin_memory=*/
false
));
}
const
std
::
shared_ptr
<
one
::
Tensor
>&
local_tensor
=
JUST
(
tensor
->
cur_rank_phy_tensor
());
const
auto
&
sbp_list
=
JUST
(
GetSbpList
(
out
->
nd_sbp
()));
return
JUST
(
one
::
functional
::
LocalToConsistent
(
local_tensor
,
out
->
placement
(),
*
sbp_list
,
*
tensor
->
shape
(),
tensor
->
dtype
()));
return
JUST
(
one
::
functional
::
LocalToGlobal
(
local_tensor
,
out
->
placement
(),
*
sbp_list
,
*
tensor
->
shape
(),
tensor
->
dtype
(),
/* sync_data */
false
,
/*copy=*/
false
));
}
COMMAND
(
RegisterBoxingFunction
(
"copy-h2d"
,
&
CheckCopyH2D
,
&
CopyBoxingFunction
));
...
...
oneflow/core/boxing/eager_boxing_interpreter.cpp
View file @
a715222c
...
...
@@ -38,7 +38,7 @@ Maybe<one::Tensor> EagerBoxingInterpreter::Interpret(const std::shared_ptr<one::
Symbol
<
ParallelDesc
>
in_parallel_desc
,
Symbol
<
ParallelDesc
>
out_parallel_desc
)
const
{
JUST
(
CheckEagerBoxingDataType
(
input
->
dtype
()
->
data_type
()));
DisableCheck
Consistent
TensorMetaScope
disable_meta_check
;
DisableCheck
Global
TensorMetaScope
disable_meta_check
;
const
auto
&
tensor
=
JUST
(
InterpretImpl
(
input
,
in_nd_sbp
,
out_nd_sbp
,
in_parallel_desc
,
out_parallel_desc
));
const
auto
&
tensor_nd_sbp
=
JUST
(
tensor
->
nd_sbp
());
...
...
oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp
View file @
a715222c
...
...
@@ -38,6 +38,13 @@ Maybe<BoxingExprIf> OptionalCudaCopy(const std::shared_ptr<BoxingExprIf>& core_b
core_boxing_expr
,
JUST
(
OptionalBoxing
(
"copy-d2h"
))))));
}
Maybe
<
BoxingExprIf
>
OptionalCpuCopy
(
const
std
::
shared_ptr
<
BoxingExprIf
>&
core_boxing_expr
)
{
return
JUST
(
BoxingExpr
(
JUST
(
ReplaceInDeviceType
(
DeviceType
::
kCPU
)),
JUST
(
OptionalBoxing
(
"copy-d2h"
)),
JUST
(
BoxingExpr
(
JUST
(
ReplaceOutDeviceType
(
DeviceType
::
kCPU
)),
core_boxing_expr
,
JUST
(
OptionalBoxing
(
"copy-h2d"
))))));
}
Maybe
<
BoxingExprIf
>
SymmetricOneDimSxToBBoxingExpr
()
{
return
JUST
(
BoxingExpr
(
JUST
(
InPlacementAndSplit
(
0
)),
JUST
(
OptionalBoxing
(
"ccl-s-to-s"
)),
JUST
(
BoxingExpr
(
"ccl-s-to-b"
))));
...
...
@@ -152,7 +159,7 @@ Maybe<BoxingExprIf> RawMainBoxingExpr() {
|
JUST
(
SymmetricNDimToOneDimBoxingExpr
())
|
JUST
(
GenericBoxingExpr
());
// clang-format on
return
core
|
JUST
(
OptionalCudaCopy
(
core
));
return
core
|
JUST
(
OptionalCudaCopy
(
core
))
|
JUST
(
OptionalCpuCopy
(
core
))
;
}
}
// namespace
...
...
oneflow/core/boxing/flatten_hierarchy.cpp
View file @
a715222c
...
...
@@ -69,8 +69,9 @@ Maybe<one::Tensor> FlattenHierarchy(const std::shared_ptr<one::Tensor>& tensor,
<<
*
JUST
(
PlacementToString
(
in
->
placement
()))
<<
")"
;
const
auto
&
local_tensor
=
JUST
(
tensor
->
cur_rank_phy_tensor
());
const
auto
&
sbp_list
=
JUST
(
GetSbpList
(
out
->
nd_sbp
()));
return
JUST
(
one
::
functional
::
LocalToConsistent
(
local_tensor
,
out
->
placement
(),
*
sbp_list
,
*
tensor
->
shape
(),
tensor
->
dtype
()));
return
JUST
(
one
::
functional
::
LocalToGlobal
(
local_tensor
,
out
->
placement
(),
*
sbp_list
,
*
tensor
->
shape
(),
tensor
->
dtype
(),
/* sync_data */
false
,
/*copy=*/
true
));
}
COMMAND
(
RegisterBoxingFunction
(
"flatten-hierarchy"
,
CheckFlattenHierarchy
,
&
FlattenHierarchy
));
...
...
oneflow/core/boxing/generic_symmetric_nd_sbp_boxing.cpp
View file @
a715222c
...
...
@@ -163,9 +163,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
<<
Error
::
RuntimeError
()
<<
"Invalid input tensor, size of local tensor ("
<<
local_tensor
->
shape
()
->
ToString
()
<<
") does not match global tensor ("
<<
logical_shape
->
ToString
()
<<
")!"
;
std
::
shared_ptr
<
one
::
Tensor
>
sub_global_tensor
=
JUST
(
one
::
functional
::
LocalTo
Consistent
(
std
::
shared_ptr
<
one
::
Tensor
>
sub_global_tensor
=
JUST
(
one
::
functional
::
LocalTo
Global
(
local_tensor
,
sub_parallel_desc
,
*
JUST
(
GetSbpList
(
one_dim_nd_sbp
)),
sub_logical_shape
,
local_tensor
->
dtype
()));
local_tensor
->
dtype
()
,
/* sync_data */
false
,
/*copy=*/
false
));
sub_global_tensor
=
JUST
(
Apply1DBoxing
(
sub_global_tensor
,
one_dim_nd_sbp
,
JUST
(
SbpToNdSbp
(
broadcast_sbp
)),
...
...
@@ -175,9 +175,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
const
auto
&
new_nd_sbp
=
JUST
(
SetSbpAtAxis
(
*
nd_sbp
,
*
broadcast_sbp
,
i
));
output
=
JUST
(
one
::
functional
::
LocalTo
Consistent
(
lo
c
al
_tensor
,
in_parallel_desc
,
*
JUST
(
GetSbpList
(
new_nd_sbp
)),
*
logical_shape
,
local_tensor
->
dtype
()
));
output
=
JUST
(
one
::
functional
::
LocalTo
G
lo
b
al
(
local_tensor
,
in_parallel_desc
,
*
JUST
(
GetSbpList
(
new_nd_sbp
)),
*
logical_shape
,
local_tensor
->
dtype
(),
/* sync_data */
false
,
/*copy=*/
false
));
}
CHECK_OR_RETURN
(
IsAllBroadcastNdSbpAfterDim
(
JUST
(
output
->
nd_sbp
()),
first_diff_sbp_dim
))
...
...
@@ -202,9 +202,9 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
std
::
shared_ptr
<
one
::
Tensor
>
local_tensor
=
JUST
(
output
->
cur_rank_phy_tensor
());
std
::
shared_ptr
<
one
::
Tensor
>
sub_global_tensor
=
JUST
(
one
::
functional
::
LocalTo
Consistent
(
std
::
shared_ptr
<
one
::
Tensor
>
sub_global_tensor
=
JUST
(
one
::
functional
::
LocalTo
Global
(
local_tensor
,
sub_parallel_desc
,
*
JUST
(
GetSbpList
(
JUST
(
SbpToNdSbp
(
broadcast_sbp
)))),
*
sub_logical_shape
,
local_tensor
->
dtype
()));
*
sub_logical_shape
,
local_tensor
->
dtype
()
,
/* sync_data */
false
,
/*copy=*/
false
));
const
auto
&
one_dim_nd_sbp
=
JUST
(
SbpToNdSbp
(
sbp_parallel
));
sub_global_tensor
=
JUST
(
Apply1DBoxing
(
sub_global_tensor
,
JUST
(
SbpToNdSbp
(
broadcast_sbp
)),
...
...
@@ -223,18 +223,18 @@ Maybe<one::Tensor> GenericSymmetricNdSbpBoxing(const std::shared_ptr<one::Tensor
const
auto
&
new_nd_sbp
=
JUST
(
SetSbpAtAxis
(
*
nd_sbp
,
sbp_parallel
,
i
));
output
=
JUST
(
one
::
functional
::
LocalTo
Consistent
(
lo
c
al
_tensor
,
in_parallel_desc
,
*
JUST
(
GetSbpList
(
new_nd_sbp
)),
*
logical_shape
,
local_tensor
->
dtype
()
));
output
=
JUST
(
one
::
functional
::
LocalTo
G
lo
b
al
(
local_tensor
,
in_parallel_desc
,
*
JUST
(
GetSbpList
(
new_nd_sbp
)),
*
logical_shape
,
local_tensor
->
dtype
(),
/* sync_data */
false
,
/*copy=*/
false
));
// physical_shape of this axis is logical shape of next axis
sub_logical_shape
=
physical_shape
;
}
}
else
{
one
::
Consistent
TensorMeta
tensor_meta
(
input
->
shape
(),
input
->
dtype
()
->
data_type
(),
out_nd_sbp
,
out_parallel_desc
);
const
auto
&
tensor_impl
=
JUST
(
one
::
Eager
Consistent
TensorImpl
::
New
(
SymbolOf
(
tensor_meta
),
input
->
requires_grad
(),
false
));
output
=
std
::
make_shared
<
one
::
Consistent
Tensor
>
(
tensor_impl
);
one
::
Global
TensorMeta
tensor_meta
(
*
input
->
shape
(),
input
->
dtype
()
->
data_type
(),
out_nd_sbp
,
out_parallel_desc
);
const
auto
&
tensor_impl
=
JUST
(
one
::
Eager
Global
TensorImpl
::
New
(
SymbolOf
(
tensor_meta
),
input
->
requires_grad
(),
false
));
output
=
std
::
make_shared
<
one
::
Global
Tensor
>
(
tensor_impl
);
}
return
output
;
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment