OpenDAS / Oneflow, commit 21d47d0e
Authored Oct 24, 2022 by yuguo
Oneflow 0.8 for DCU
Changes: 556 files. Showing 20 changed files with 2206 additions and 0 deletions (+2206, -0).
oneflow/core/autograd/autograd_mode.cpp  (+38, -0)
oneflow/core/autograd/autograd_mode.h  (+48, -0)
oneflow/core/autograd/gradient_funcs/activation.cpp  (+563, -0)
oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp  (+93, -0)
oneflow/core/autograd/gradient_funcs/add_n.cpp  (+54, -0)
oneflow/core/autograd/gradient_funcs/affine_grid.cpp  (+69, -0)
oneflow/core/autograd/gradient_funcs/as_strided.cpp  (+83, -0)
oneflow/core/autograd/gradient_funcs/avg_pool.cpp  (+103, -0)
oneflow/core/autograd/gradient_funcs/batch_gather.cpp  (+67, -0)
oneflow/core/autograd/gradient_funcs/bias_add.cpp  (+77, -0)
oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp  (+78, -0)
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp  (+101, -0)
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp  (+80, -0)
oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp  (+352, -0)
oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp  (+49, -0)
oneflow/core/autograd/gradient_funcs/broadcast_like.cpp  (+74, -0)
oneflow/core/autograd/gradient_funcs/cast.cpp  (+57, -0)
oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp  (+75, -0)
oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp  (+72, -0)
oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp  (+73, -0)
Too many changes to show: to preserve performance only 556 of 556+ files are displayed.
oneflow/core/autograd/autograd_mode.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/autograd/autograd_mode.h"
namespace oneflow {
namespace autograd {

namespace {

bool* GetThreadLocalGradMode() {
  static thread_local bool g_grad_mode = true;
  return &g_grad_mode;
}

}  // namespace

bool GradMode::is_enabled() { return *GetThreadLocalGradMode(); }

void GradMode::set_enabled(bool enabled) { *GetThreadLocalGradMode() = enabled; }

}  // namespace autograd
}  // namespace oneflow
oneflow/core/autograd/autograd_mode.h (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
namespace oneflow {
namespace autograd {

struct GradMode {
  static bool is_enabled();
  static void set_enabled(bool enabled);
};

class AutoGradMode {
 public:
  AutoGradMode(bool enabled) : prev_mode_(GradMode::is_enabled()) { GradMode::set_enabled(enabled); }
  ~AutoGradMode() { GradMode::set_enabled(prev_mode_); }
  bool prev_mode() const { return prev_mode_; }

 private:
  bool prev_mode_;
};

class NoGradGuard : public AutoGradMode {
 public:
  NoGradGuard() : AutoGradMode(false){};
};

}  // namespace autograd
}  // namespace oneflow
#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
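
A minimal usage sketch, not part of this commit, of the guards declared above: AutoGradMode records the current thread-local grad mode, forces the requested mode, and restores the previous mode when it goes out of scope, so NoGradGuard disables gradient recording for a block in RAII style. Only GradMode, AutoGradMode, and NoGradGuard come from autograd_mode.h; the surrounding function is a hypothetical placeholder.

// Sketch only (assumes the OneFlow source tree for the include); RAII toggling of grad mode.
#include "oneflow/core/autograd/autograd_mode.h"

void RunWithoutGradRecording() {
  using namespace oneflow::autograd;
  const bool was_enabled = GradMode::is_enabled();  // typically true by default
  {
    NoGradGuard no_grad;  // equivalent to AutoGradMode(false)
    // ... forward-only work here; GradMode::is_enabled() returns false in this scope ...
  }
  // The guard's destructor has restored the previous mode here.
  (void)was_enabled;
}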
oneflow/core/autograd/gradient_funcs/activation.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BaseActivationCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }
};

class Silu : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::SiluGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

class Mish : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::MishGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

class Selu : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::SeluGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

class Softsign : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::SoftSignGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

class GeLU : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

class HardSigmoid : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

struct HardShrinkCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  double lambd = 0.5;
};

class HardShrink : public OpExprGradFunction<HardShrinkCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
      JUST(oneflow::VectorAt(*in_grads, 0)) =
          JUST(functional::HardShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

class HardSwish : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
    }
    return Maybe<void>::Ok();
  }
};

// ===== Activation with parms ====
struct ReLUCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class ReLU : public OpExprGradFunction<ReLUCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
    }
    return Maybe<void>::Ok();
  }
};

// ===== Activation with parms ====
struct LeakyReluCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  float alpha;
};

class LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::LeakyReluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct SoftplusCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  double beta = 1.0;
  double threshold = 20.0;
};

class Softplus : public OpExprGradFunction<SoftplusCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SoftplusCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->beta = JUST(composed_attrs.GetAttr<double>("beta"));
    ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold"));
    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SoftplusCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::SoftplusGrad(
          x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->beta, ctx->threshold));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct HardTanhCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  double min_val;
  double max_val;
};

class HardTanh : public OpExprGradFunction<HardTanhCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->min_val = JUST(composed_attrs.GetAttr<double>("min_val"));
    ctx->max_val = JUST(composed_attrs.GetAttr<double>("max_val"));
    ctx->SaveTensorForBackward(outputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& y = ctx->SavedTensors().at(0);
      in_grads->at(0) =
          JUST(functional::HardTanhGrad(y, out_grads.at(0), ctx->min_val, ctx->max_val));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct EluCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  double alpha;
};

class Elu : public OpExprGradFunction<EluCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::EluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct CeluCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  double alpha = 1.0;
};

class Celu : public OpExprGradFunction<CeluCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CeluCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CeluCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::CeluGrad(x, out_grads.at(0), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct SoftShrinkCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  double alpha = 0.5;
};

class SoftShrink : public OpExprGradFunction<SoftShrinkCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(SoftShrinkCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->alpha = JUST(composed_attrs.GetAttr<double>("alpha"));
    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const SoftShrinkCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
      JUST(oneflow::VectorAt(*in_grads, 0)) =
          JUST(functional::SoftShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->alpha));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct PReLUCaptureState : public AutoGradCaptureState {
  bool input_requires_grad;
  bool alpha_requires_grad;
};

class PReLU : public OpExprGradFunction<PReLUCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->input_requires_grad = inputs.at(0)->requires_grad();  // input
    ctx->alpha_requires_grad = inputs.at(1)->requires_grad();  // alpha
    ctx->SaveTensorForBackward(inputs.at(0));
    ctx->SaveTensorForBackward(inputs.at(1));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& dy = out_grads.at(0);
    const auto& x = ctx->SavedTensors().at(0);
    const auto& alpha = ctx->SavedTensors().at(1);
    in_grads->resize(2);
    if (ctx->input_requires_grad || ctx->alpha_requires_grad) {
      const auto& grads = JUST(functional::PReluGrad(dy, x, alpha));
      if (ctx->input_requires_grad) { in_grads->at(0) = grads->at(0); }
      if (ctx->alpha_requires_grad) { in_grads->at(1) = grads->at(1); }
    }
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;
};

struct ThresholdCaptureState : public AutoGradCaptureState {
  bool requires_grad = true;
  double threshold = 0.0;
};

class Threshold : public OpExprGradFunction<ThresholdCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ThresholdCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->threshold = JUST(composed_attrs.GetAttr<double>("threshold_val"));
    ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ThresholdCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
          functional::ThresholdGrad(x, JUST(oneflow::VectorAt(out_grads, 0)), ctx->threshold));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("silu", Silu);
REGISTER_OP_EXPR_GRAD_FUNCTION("mish", Mish);
REGISTER_OP_EXPR_GRAD_FUNCTION("selu", Selu);
REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", Softsign);
REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("gelu", GeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardsigmoid", HardSigmoid);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink", HardShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardswish", HardSwish);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh);
REGISTER_OP_EXPR_GRAD_FUNCTION("elu", Elu);
REGISTER_OP_EXPR_GRAD_FUNCTION("celu", Celu);
REGISTER_OP_EXPR_GRAD_FUNCTION("prelu", PReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("threshold", Threshold);
REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);

}  // namespace one
}  // namespace oneflow
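
For reference, a hedged sketch of the pattern this file follows for simple elementwise activations: subclass BaseActivation (which already captures the input and the requires_grad flag), implement only Apply with the matching backward functional, and register the class under the op's type name. The op name "my_act" and functional::MyActGrad below are hypothetical placeholders for illustration only; they are not ops or functionals that exist in OneFlow.

// Sketch only, not part of this commit. "my_act" and functional::MyActGrad are
// hypothetical placeholders; everything else mirrors the Silu/Mish classes above.
class MyAct : public BaseActivation {
 public:
  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);  // input saved by BaseActivation::Capture
      in_grads->at(0) = JUST(functional::MyActGrad(out_grads.at(0), x));  // hypothetical functional
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("my_act", MyAct);  // "my_act" is a hypothetical op type name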
oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct AdaptivePoolCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {
 public:
  using OpExprGradFunction<AdaptivePoolCaptureState>::Init;

  Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);
  Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
  std::string mode_;
  int32_t ndims_;
};

Maybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const int& ndims) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  mode_ = mode;
  ndims_ = ndims;
  return Maybe<void>::Ok();
}

Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
                                        const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::AdaptivePoolNdGrad(x, out_grads.at(0), mode_, ndims_));
  return Maybe<void>::Ok();
}

class AdaptiveAvgPool1dGrad final : public AdaptivePoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 1); }
};

class AdaptiveAvgPool2dGrad final : public AdaptivePoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 2); }
};

class AdaptiveAvgPool3dGrad final : public AdaptivePoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return AdaptivePoolNdGrad::Init(op, "avg", 3); }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool1d", AdaptiveAvgPool1dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool2d", AdaptiveAvgPool2dGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("adaptive_avg_pool3d", AdaptiveAvgPool3dGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/add_n.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {

struct AddNCaptureState : public AutoGradCaptureState {
  int32_t input_num;
  std::vector<bool> requires_grad;
};

class AddN : public OpExprGradFunction<AddNCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->input_num = inputs.size();
    ctx->requires_grad.resize(inputs.size());
    for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(ctx->input_num);
    for (int i = 0; i < ctx->input_num; ++i) {
      if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); }
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("add_n", AddN);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/affine_grid.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct AffineGridInterpState : public AutoGradCaptureState {
  Shape size;
  bool align_corners = false;
  bool requires_grad = false;
};

class AffineGrid : public OpExprGradFunction<AffineGridInterpState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(AffineGridInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();  // theta
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->size = JUST(composed_attrs.GetAttr<Shape>("size"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AffineGridInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    in_grads->at(0) =
        JUST(functional::AffineGridGrad(out_grads.at(0), ctx->size, ctx->align_corners));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("affine_grid", AffineGrid);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/as_strided.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct AsStridedCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> size;
  std::vector<int32_t> stride;
  int32_t storage_offset = 0;
  bool requires_grad = false;
};

class AsStrided : public OpExprGradFunction<AsStridedCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> AsStrided::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> AsStrided::Capture(AsStridedCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ctx->SaveTensorForBackward(inputs.at(0));
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("size"));
  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
  ctx->storage_offset = JUST(composed_attrs.GetAttr<int32_t>("storage_offset"));
  return Maybe<void>::Ok();
}

Maybe<void> AsStrided::Apply(const AsStridedCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& input = ctx->SavedTensors().at(0);
  std::vector<int32_t> size = ctx->size;
  std::vector<int32_t> stride = ctx->stride;
  int32_t storage_offset = ctx->storage_offset;
  in_grads->at(0) =
      JUST(functional::AsStridedGrad(out_grads.at(0), input, size, stride, storage_offset));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("as_strided", AsStrided);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/avg_pool.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

namespace {

struct AvgPoolCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  size_t input_index = 0;
  std::string data_format;
  std::vector<int32_t> padding;
  std::vector<int32_t> kernel_size;
  std::vector<int32_t> stride;
  bool ceil_mode = false;
  bool count_include_pad = false;
  int32_t divisor_override = 0;
};

class AvgPoolNdGrad : public OpExprGradFunction<AvgPoolCaptureState> {
 public:
  virtual ~AvgPoolNdGrad() = default;
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> AvgPoolNdGrad::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> AvgPoolNdGrad::Capture(AvgPoolCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
  ctx->count_include_pad = JUST(composed_attrs.GetAttr<bool>("count_include_pad"));
  ctx->divisor_override = JUST(composed_attrs.GetAttr<int32_t>("divisor_override"));
  return Maybe<void>::Ok();
}

Maybe<void> AvgPoolNdGrad::Apply(const AvgPoolCaptureState* ctx, const TensorTuple& out_grads,
                                 TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  int32_t ndims = ctx->kernel_size.size();
  const auto& input = ctx->SavedTensors().at(ctx->input_index);
  in_grads->resize(1);
  (*in_grads)[0] = JUST(functional::AvgPoolNdGrad(
      input, out_grads[0], ndims, ctx->data_format, ctx->padding, ctx->kernel_size, ctx->stride,
      ctx->ceil_mode, ctx->count_include_pad, ctx->divisor_override));
  return Maybe<void>::Ok();
}

}  // namespace

REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_1d", AvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_2d", AvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("avg_pool_3d", AvgPoolNdGrad);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/batch_gather.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BatchGatherCaptureState : public AutoGradCaptureState {
  int64_t num_segments;
  bool requires_grad;
};

class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

Maybe<void> BatchGather::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

Maybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
                                 const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  const auto& in_shape = inputs.at(0)->shape();
  const auto& indices_shape = inputs.at(1)->shape();
  ctx->num_segments = in_shape->At(indices_shape->NumAxes() - 1);
  ctx->SaveTensorForBackward(inputs.at(1));
  return Maybe<void>::Ok();
}

Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
                               TensorTuple* in_grads) const {
  in_grads->resize(2);
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  const auto& indices = ctx->SavedTensors().at(0);
  in_grads->at(0) =
      JUST(functional::UnsortedBatchSegmentSum(out_grads.at(0), indices, ctx->num_segments));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("batch_gather", BatchGather);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/bias_add.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BiasAddCaptureState : public AutoGradCaptureState {
  bool input_requires_grad;
  bool bias_requires_grad;
  int32_t axis;
};

class BiasAdd : public OpExprGradFunction<BiasAddCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(BiasAddCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->input_requires_grad = inputs.at(0)->requires_grad();
    ctx->bias_requires_grad = inputs.at(1)->requires_grad();
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->axis = JUST(composed_attrs.GetAttr<int32_t>("axis"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const BiasAddCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
    in_grads->resize(2);
    if (ctx->bias_requires_grad) {
      std::vector<int32_t> reduce_axes_vec;
      reduce_axes_vec.reserve(num_axes);
      for (int i = 0; i < num_axes; ++i) {
        if (i != ctx->axis) { reduce_axes_vec.emplace_back(i); }
      }
      if (ctx->bias_requires_grad) {
        in_grads->at(1) = JUST(functional::ReduceSum(out_grads.at(0), reduce_axes_vec, false));
      }
    }
    if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("bias_add", BiasAdd);

}  // namespace one
}  // namespace oneflow
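
A small standalone sketch, not part of this commit, of the axis bookkeeping in BiasAdd::Apply above: the bias gradient sums dy over every axis except the bias axis, so dy of shape (N, C, H, W) with axis = 1 is reduced over {0, 2, 3}. The helper name BiasReduceAxes is introduced here purely for illustration.

// Standalone illustration of how BiasAdd::Apply chooses the reduction axes for the bias gradient.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> BiasReduceAxes(int64_t num_axes, int32_t bias_axis) {
  std::vector<int32_t> reduce_axes;
  reduce_axes.reserve(num_axes);
  for (int32_t i = 0; i < num_axes; ++i) {
    if (i != bias_axis) { reduce_axes.emplace_back(i); }
  }
  return reduce_axes;
}

int main() {
  // dy with 4 axes (N, C, H, W) and bias axis 1 -> reduce over 0, 2, 3.
  for (int32_t axis : BiasReduceAxes(4, 1)) { std::cout << axis << ' '; }
  std::cout << '\n';
  return 0;
}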
oneflow/core/autograd/gradient_funcs/binary_cross_entropy.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class BinaryCrossEntropy : public OpExprGradFunction<BinaryCrossEntropyCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropy::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropy::Capture(BinaryCrossEntropyCaptureState* ctx,
                                        const TensorTuple& inputs, const TensorTuple& outputs,
                                        const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->SaveTensorForBackward(inputs.at(0));  // input
  ctx->SaveTensorForBackward(inputs.at(1));  // target
  if (inputs.size() == 3) {
    ctx->SaveTensorForBackward(inputs.at(2));  // weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropy::Apply(const BinaryCrossEntropyCaptureState* ctx,
                                      const TensorTuple& out_grads, TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& dy = out_grads.at(0);
  const auto& input = ctx->SavedTensors().at(0);
  const auto& target = ctx->SavedTensors().at(1);
  in_grads->resize(ctx->SavedTensors().size());
  if (ctx->SavedTensors().size() == 3) {
    const auto& weight = ctx->SavedTensors().at(2);
    in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, weight));
  } else {
    in_grads->at(0) = JUST(functional::BinaryCrossEntropyLossGrad(dy, input, target, NullOpt));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy", BinaryCrossEntropy);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  bool has_pos_weight = false;
};

class BinaryCrossEntropyWithLogits
    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropyWithLogits::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogits::Capture(BinaryCrossEntropyWithLogitsCaptureState* ctx,
                                                  const TensorTuple& inputs,
                                                  const TensorTuple& outputs,
                                                  const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->has_pos_weight = JUST(composed_attrs.GetAttr<bool>("has_pos_weight"));
  ctx->SaveTensorForBackward(inputs.at(0));  // input
  ctx->SaveTensorForBackward(inputs.at(1));  // target
  if (inputs.size() == 3) {
    ctx->SaveTensorForBackward(inputs.at(2));  // weight or pos_weight
  }
  if (inputs.size() == 4) {
    ctx->SaveTensorForBackward(inputs.at(2));  // weight
    ctx->SaveTensorForBackward(inputs.at(3));  // pos_weight
  }
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogits::Apply(
    const BinaryCrossEntropyWithLogitsCaptureState* ctx, const TensorTuple& out_grads,
    TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& dy = out_grads.at(0);
  const auto& input = ctx->SavedTensors().at(0);
  const auto& target = ctx->SavedTensors().at(1);
  in_grads->resize(ctx->SavedTensors().size());
  if (ctx->SavedTensors().size() == 3) {
    if (ctx->has_pos_weight) {
      const auto& pos_weight = ctx->SavedTensors().at(2);
      in_grads->at(0) = JUST(
          functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, pos_weight));
    } else {
      const auto& weight = ctx->SavedTensors().at(2);
      in_grads->at(0) = JUST(
          functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, NullOpt));
    }
  } else if (ctx->SavedTensors().size() == 4) {
    const auto& weight = ctx->SavedTensors().at(2);
    const auto& pos_weight = ctx->SavedTensors().at(3);
    in_grads->at(0) = JUST(
        functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, weight, pos_weight));
  } else {
    in_grads->at(0) = JUST(
        functional::BinaryCrossEntropyWithLogitsLossGrad(dy, input, target, NullOpt, NullOpt));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits", BinaryCrossEntropyWithLogits);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/binary_cross_entropy_with_logits_reduce_mean.cpp (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BinaryCrossEntropyWithLogitsReduceMeanCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  bool has_pos_weight = false;
};

class BinaryCrossEntropyWithLogitsReduceMean
    : public OpExprGradFunction<BinaryCrossEntropyWithLogitsReduceMeanCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
                      const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  Maybe<void> Apply(const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx,
                    const TensorTuple& out_grads, TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be null. ";
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Capture(
    BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& inputs,
    const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = JUST(VectorAt(inputs, 0))->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 0)));  // input
  ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target
  return Maybe<void>::Ok();
}

Maybe<void> BinaryCrossEntropyWithLogitsReduceMean::Apply(
    const BinaryCrossEntropyWithLogitsReduceMeanCaptureState* ctx, const TensorTuple& out_grads,
    TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "out_grads size should be equal to 1. ";
  const auto& dy = JUST(VectorAt(out_grads, 0));
  const auto& input = JUST(VectorAt(ctx->SavedTensors(), 0));
  const auto& target = JUST(VectorAt(ctx->SavedTensors(), 1));
  in_grads->resize(ctx->SavedTensors().size());
  JUST(VectorAt(*in_grads, 0)) =
      JUST(functional::BinaryCrossEntropyWithLogitsReduceMeanLossGrad(dy, input, target));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("binary_cross_entropy_with_logits_reduce_mean",
                               BinaryCrossEntropyWithLogitsReduceMean);

}  // namespace one
}  // namespace oneflow
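The reduce_mean variant averages the loss over all elements, so its backward additionally divides by the element count. A minimal standalone sketch of that math (plain C++, not OneFlow code), assuming the scalar mean loss receives upstream gradient dy:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {-1.0, 0.0, 2.0};      // logits
  std::vector<double> target = {0.0, 1.0, 1.0};  // labels
  const double dy = 1.0;                         // gradient of the scalar mean loss
  const double n = static_cast<double>(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double sig = 1.0 / (1.0 + std::exp(-x[i]));
    std::printf("%f\n", dy * (sig - target[i]) / n);  // dL/dx_i with mean reduction
  }
  return 0;
}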
oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BroadcastBinaryCaptureState : public AutoGradCaptureState {
  int x_index = -1;
  int y_index = -1;
  int z_index = -1;
  bool x_requires_grad = false;
  bool y_requires_grad = false;
  bool broadcast_x = false;
  bool broadcast_y = false;
};

class BroadcastBinaryGrad : public OpExprGradFunction<BroadcastBinaryCaptureState> {
 public:
  BroadcastBinaryGrad() = default;
  virtual ~BroadcastBinaryGrad() = default;

  virtual Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->broadcast_x = (*inputs.at(0)->shape() != *outputs.at(0)->shape());
    ctx->broadcast_y = (*inputs.at(1)->shape() != *outputs.at(0)->shape());
    return SaveTensorForBackward(ctx, inputs, outputs);
  }

 protected:
  virtual Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx,
                                            const TensorTuple& inputs,
                                            const TensorTuple& outputs) const = 0;
};

class BroadcastAdd : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      if (ctx->broadcast_x) {
        const auto& x = ctx->SavedTensors().at(ctx->x_index);
        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
      } else {
        in_grads->at(0) = out_grads.at(0);
      }
    }
    if (ctx->y_requires_grad) {
      if (ctx->broadcast_y) {
        const auto& y = ctx->SavedTensors().at(ctx->y_index);
        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), y));
      } else {
        in_grads->at(1) = out_grads.at(0);
      }
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad && ctx->broadcast_x) {
      ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    }
    if (ctx->y_requires_grad && ctx->broadcast_y) {
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);

class BroadcastSub : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      if (ctx->broadcast_x) {
        const auto& x = ctx->SavedTensors().at(ctx->x_index);
        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(out_grads.at(0), x));
      } else {
        in_grads->at(0) = out_grads.at(0);
      }
    }
    if (ctx->y_requires_grad) {
      const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), Scalar(-1.f), false));
      if (ctx->broadcast_y) {
        const auto& y = ctx->SavedTensors().at(ctx->y_index);
        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(grad, y));
      } else {
        in_grads->at(1) = grad;
      }
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad && ctx->broadcast_x) {
      ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    }
    if (ctx->y_requires_grad && ctx->broadcast_y) {
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);

class BroadcastMul : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y));
      if (ctx->broadcast_x) {
        const auto& x = ctx->SavedTensors().at(ctx->x_index);
        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
      } else {
        in_grads->at(0) = x_grad;
      }
    }
    if (ctx->y_requires_grad) {
      const auto& x = ctx->SavedTensors().at(ctx->x_index);
      const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x));
      if (ctx->broadcast_y) {
        const auto& y = ctx->SavedTensors().at(ctx->y_index);
        in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
      } else {
        in_grads->at(1) = y_grad;
      }
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad) {
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
      if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
    }
    if (ctx->y_requires_grad) {
      if (ctx->x_index == -1 /*x has not been saved*/) {
        ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
      }
      if (ctx->broadcast_y && ctx->y_index == -1 /*y has not been saved*/) {
        ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
      }
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);

class BroadcastDiv : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      const auto& x_grad = JUST(functional::Div(out_grads.at(0), y));
      if (ctx->broadcast_x) {
        const auto& x = ctx->SavedTensors().at(ctx->x_index);
        in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
      } else {
        in_grads->at(0) = x_grad;
      }
    }
    if (ctx->y_requires_grad) {
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      const auto& z = ctx->SavedTensors().at(ctx->z_index);
      in_grads->at(1) = JUST(functional::DivGrad(out_grads.at(0), z, y));
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad) {
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
      if (ctx->broadcast_x) { ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0)); }
    }
    if (ctx->y_requires_grad) {
      if (ctx->y_index == -1 /*y has not been saved*/) {
        ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
      }
      ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
    }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);

class BroadcastPow : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& x = ctx->SavedTensors().at(ctx->x_index);
    const auto& y = ctx->SavedTensors().at(ctx->y_index);
    const auto& z = ctx->SavedTensors().at(ctx->z_index);
    in_grads->resize(2);
    if (ctx->x_requires_grad) {
      in_grads->at(0) = JUST(functional::BroadcastPowXGrad(out_grads.at(0), x, y, z));
    }
    if (ctx->y_requires_grad) {
      in_grads->at(1) = JUST(functional::BroadcastPowYGrad(out_grads.at(0), x, y, z));
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
    ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    ctx->z_index = ctx->SaveTensorForBackward(outputs.at(0));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_pow", BroadcastPow);

class BroadcastMinMax : public BroadcastBinaryGrad {
 public:
  Maybe<void> Apply(const BroadcastBinaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& out_shape = *(out_grads.at(0)->shape());
    in_grads->resize(2);
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      const auto& x = ctx->SavedTensors().at(ctx->x_index);
      const auto& y = ctx->SavedTensors().at(ctx->y_index);
      auto broad_x_ = x;
      auto broad_y_ = y;
      if (ctx->broadcast_x) {
        const auto& x_shape = *(x->shape());
        const Shape& left_extended_x_shape =
            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
        if (left_extended_x_shape == out_shape) {
          broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
        } else {
          const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
          const std::vector<int32_t> x_axis =
              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
          broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
        }
      }
      if (ctx->broadcast_y) {
        const auto& y_shape = *(y->shape());
        const Shape& left_extended_y_shape =
            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
        if (left_extended_y_shape == out_shape) {
          broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
        } else {
          const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
          const std::vector<int32_t> y_axis =
              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
          broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
        }
      }
      const auto& broad_grads =
          JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
      if (ctx->x_requires_grad) {
        if (ctx->broadcast_x) {
          in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(0), x));
        } else {
          in_grads->at(0) = broad_grads->at(0);
        }
      }
      if (ctx->y_requires_grad) {
        if (ctx->broadcast_y) {
          in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_grads->at(1), y));
        } else {
          in_grads->at(1) = broad_grads->at(1);
        }
      }
    }
    return Maybe<void>::Ok();
  }

 protected:
  Maybe<void> SaveTensorForBackward(BroadcastBinaryCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs) const override {
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      ctx->x_index = ctx->SaveTensorForBackward(inputs.at(0));
      ctx->y_index = ctx->SaveTensorForBackward(inputs.at(1));
    }
    return Maybe<void>::Ok();
  }

  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
                                   const std::shared_ptr<Tensor>&)>
      elementwise_grad_functor_;
};

class BroadcastMinimum : public BroadcastMinMax {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    JUST(BroadcastMinMax::Init(op));
    elementwise_grad_functor_ = functional::ElementwiseMinGrad;
    return Maybe<void>::Ok();
  }
};

class BroadcastMaximum : public BroadcastMinMax {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    JUST(BroadcastMinMax::Init(op));
    elementwise_grad_functor_ = functional::ElementwiseMaxGrad;
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);

}  // namespace one
}  // namespace oneflow
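The step shared by all of these backward rules is reducing the output-shaped gradient back to an input's shape by summing over the axes that were broadcast (what functional::BroadcastReduceSumLike does above); the op-specific part is only the elementwise factor applied first (none for add, -1 for the subtrahend, the other operand for mul, and so on). A minimal standalone sketch of that reduction for the broadcast_add case (plain C++, not OneFlow code), assuming x of shape (1, 3) was broadcast against a (2, 3) operand:

#include <array>
#include <cstdio>

int main() {
  const int rows = 2, cols = 3;
  std::array<double, rows * cols> dz = {1, 2, 3, 4, 5, 6};  // gradient of z = x + y, shape (2, 3)
  std::array<double, cols> dx = {0, 0, 0};                  // gradient w.r.t. x, shape (1, 3)
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) { dx[c] += dz[r * cols + c]; }  // sum over broadcast axis 0
  }
  for (double v : dx) { std::printf("%f\n", v); }  // prints 5, 7, 9
  return 0;
}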
oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {

struct BroadcastFModCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    if (ctx->requires_grad) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_fmod", BroadcastFMod);

}  // namespace one
}  // namespace oneflow
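Since d(fmod(x, y))/dx is 1 almost everywhere, the backward above simply forwards the upstream gradient to x and produces no gradient for y. A quick standalone numeric check of that derivative (plain C++, not OneFlow code):

#include <cmath>
#include <cstdio>

int main() {
  const double x = 7.5, y = 2.0, eps = 1e-6;
  // Central difference of fmod(x, y) with respect to x, away from discontinuities.
  const double numeric = (std::fmod(x + eps, y) - std::fmod(x - eps, y)) / (2 * eps);
  std::printf("numeric d(fmod)/dx ~= %f\n", numeric);  // ~1.0
  return 0;
}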
oneflow/core/autograd/gradient_funcs/broadcast_like.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct BroadCastLikeCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  size_t input_index;
  std::vector<int32_t> broadcast_axes;
};

class BroadCastLike : public OpExprGradFunction<BroadCastLikeCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> BroadCastLike::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> BroadCastLike::Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                                   const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->broadcast_axes = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("broadcast_axes"));
  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> BroadCastLike::Apply(const BroadCastLikeCaptureState* ctx,
                                 const TensorTuple& out_grads, TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const auto& x = ctx->SavedTensors().at(ctx->input_index);
  in_grads->resize(2);
  in_grads->at(0) = JUST(functional::ReduceSumLike(out_grads.at(0), x, ctx->broadcast_axes));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_like", BroadCastLike);

}  // namespace one
}  // namespace oneflow
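broadcast_like expands its first input along broadcast_axes to match the like tensor, so the backward is the matching reduction: the output gradient is summed over those same axes (functional::ReduceSumLike above); in_grads is resized to 2 because the op has two inputs, but only the first receives a gradient. A minimal standalone sketch (plain C++, not OneFlow code), assuming broadcast_axes = {0}:

#include <cstdio>
#include <vector>

int main() {
  // x has shape (3); broadcast_axes = {0} expands it to shape (4, 3) to match "like".
  const int expanded_rows = 4, cols = 3;
  std::vector<double> dy(expanded_rows * cols, 1.0);  // output-shaped gradient
  std::vector<double> dx(cols, 0.0);                  // gradient w.r.t. x
  for (int r = 0; r < expanded_rows; ++r) {
    for (int c = 0; c < cols; ++c) { dx[c] += dy[r * cols + c]; }  // sum over broadcast axis 0
  }
  for (double v : dx) { std::printf("%f\n", v); }  // each entry is 4
  return 0;
}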
oneflow/core/autograd/gradient_funcs/cast.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {
namespace one {

struct CastCaptureState : public AutoGradCaptureState {
  Symbol<DType> dtype;
};

class Cast : public OpExprGradFunction<CastCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CastCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->dtype = inputs.at(0)->dtype();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CastCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    (*in_grads)[0] = JUST(functional::Cast(out_grads[0], ctx->dtype, /*pin_memory=*/false));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cast", Cast);

}  // namespace one
}  // namespace oneflow
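The backward of a dtype cast is just another cast: the upstream gradient is converted back to the dtype captured from the input. A minimal standalone sketch (plain C++, not OneFlow code), illustrated with a double input that was cast to float in the forward pass:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> dy = {0.25f, -1.5f};  // gradient arriving in the output dtype (float)
  std::vector<double> dx(dy.size());
  for (std::size_t i = 0; i < dy.size(); ++i) {
    dx[i] = static_cast<double>(dy[i]);  // cast the gradient back to the input dtype (double)
  }
  for (double v : dx) { std::printf("%f\n", v); }
  return 0;
}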
oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct ClipByScalarCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  Scalar min;
  Scalar max;
};

class ClipByScalar : public OpExprGradFunction<ClipByScalarCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ClipByScalarCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ctx->SaveTensorForBackward(inputs.at(0));
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>("floating_min")));
      ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_min")));
      ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
    } else {
      UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ClipByScalarCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, ctx->max));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar", ClipByScalar);

}  // namespace one
}  // namespace oneflow
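functional::ClampGrad masks the upstream gradient using the saved input: positions that were inside [min, max] pass their gradient through, clamped positions get zero; the same rule with only one bound applies to the scalar-max and scalar-min variants below. A minimal standalone sketch of that masking (plain C++, not OneFlow code; the inclusive boundary handling here is one common convention and only an assumption):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const double lo = 0.0, hi = 1.0;                // clip bounds
  std::vector<double> x = {-0.5, 0.3, 0.9, 2.0};  // saved forward input
  std::vector<double> dy = {1.0, 1.0, 1.0, 1.0};  // upstream gradient
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double dx = (x[i] >= lo && x[i] <= hi) ? dy[i] : 0.0;  // zero where the input was clamped
    std::printf("%f\n", dx);  // prints 0, 1, 1, 0
  }
  return 0;
}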
oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct ClipByScalarMaxCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  Scalar max;
};

class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ClipByScalarMaxCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ctx->SaveTensorForBackward(inputs.at(0));
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->max = Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->max = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
    } else {
      UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ClipByScalarMaxCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) =
          JUST(functional::ClampGrad(out_grads.at(0), x, /*min=*/NullOpt, ctx->max));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_max", ClipByScalarMax);

}  // namespace one
}  // namespace oneflow
oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct ClipByScalarMinCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  Scalar min;
};

class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ClipByScalarMinCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ctx->SaveTensorForBackward(inputs.at(0));
    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    if (IsFloatingDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->min = Scalar(JUST(composed_attrs.GetAttr<double>("floating_min")));
    } else if (IsIntegralDataType(inputs.at(0)->dtype()->data_type())) {
      ctx->min = Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_min")));
    } else {
      UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ClipByScalarMinCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) =
          JUST(functional::ClampGrad(out_grads.at(0), x, ctx->min, /*max=*/NullOpt));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_min", ClipByScalarMin);

}  // namespace one
}  // namespace oneflow