OpenDAS / Oneflow · Commits

Commit 21d47d0e, authored Oct 24, 2022 by yuguo

Oneflow 0.8 for DCU

Changes: 556 files. Showing 20 changed files with 2465 additions and 0 deletions (+2465, -0).
oneflow/core/autograd/gradient_funcs/tf_pool.cpp            +124  -0
oneflow/core/autograd/gradient_funcs/to_contiguous.cpp       +48  -0
oneflow/core/autograd/gradient_funcs/transpose.cpp           +73  -0
oneflow/core/autograd/gradient_funcs/tril.cpp                +69  -0
oneflow/core/autograd/gradient_funcs/triu.cpp                +69  -0
oneflow/core/autograd/gradient_funcs/two_stage_reduce.cpp   +142  -0
oneflow/core/autograd/gradient_funcs/unfold.cpp              +84  -0
oneflow/core/autograd/gradient_funcs/unfold_tensor.cpp       +76  -0
oneflow/core/autograd/gradient_funcs/unsqueeze.cpp           +80  -0
oneflow/core/autograd/gradient_funcs/upsample.cpp           +434  -0
oneflow/core/autograd/gradient_funcs/variance.cpp           +109  -0
oneflow/core/autograd/gradient_funcs/where.cpp              +130  -0
oneflow/core/boxing/asymmetric_broadcast.cpp                +133  -0
oneflow/core/boxing/boxing_dividor.h                         +52  -0
oneflow/core/boxing/boxing_dividor_util.cpp                 +310  -0
oneflow/core/boxing/boxing_dividor_util.h                    +40  -0
oneflow/core/boxing/boxing_interpreter_status.cpp           +130  -0
oneflow/core/boxing/boxing_interpreter_status.h             +104  -0
oneflow/core/boxing/ccl_boxing_function.cpp                 +176  -0
oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp         +82  -0
Too many changes to show: to preserve performance, only 556 of 556+ files are displayed.

oneflow/core/autograd/gradient_funcs/tf_pool.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

namespace {

struct TFPoolCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  size_t input_index = 0;
  size_t output_index = 0;

  std::string data_format;
  std::string padding;
  std::vector<int32_t> padding_before;
  std::vector<int32_t> padding_after;
  std::vector<int32_t> pool_size;
  std::vector<int32_t> strides;
  bool ceil_mode = false;
};

class TFPoolNdGrad : public OpExprGradFunction<TFPoolCaptureState> {
 public:
  virtual ~TFPoolNdGrad() = default;

  using OpExprGradFunction<TFPoolCaptureState>::Init;

  Maybe<void> Init(const OpExpr& op, const std::string& mode);
  Maybe<void> Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  std::string mode_;
  AttrMap base_attrs_;
};

Maybe<void> TFPoolNdGrad::Init(const OpExpr& op, const std::string& mode) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  mode_ = mode;
  return Maybe<void>::Ok();
}

Maybe<void> TFPoolNdGrad::Capture(TFPoolCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
  ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0));

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding"));
  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
  ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after"));
  ctx->pool_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("pool_size"));
  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
  return Maybe<void>::Ok();
}

Maybe<void> TFPoolNdGrad::Apply(const TFPoolCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  int32_t ndims = ctx->pool_size.size();
  const auto& input = ctx->SavedTensors().at(ctx->input_index);
  const auto& output = ctx->SavedTensors().at(ctx->output_index);

  in_grads->resize(1);
  (*in_grads)[0] = JUST(functional::TFPoolNdGrad(input, output, out_grads[0], mode_, ndims,
                                                 ctx->data_format, ctx->padding,
                                                 ctx->padding_before, ctx->padding_after,
                                                 ctx->pool_size, ctx->strides, ctx->ceil_mode));
  return Maybe<void>::Ok();
}

}  // namespace

class TFMaxPoolNdGrad final : public TFPoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_max"); }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_1d", TFMaxPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_2d", TFMaxPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_max_pool_3d", TFMaxPoolNdGrad);

class TFAvgPoolNdGrad final : public TFPoolNdGrad {
 public:
  Maybe<void> Init(const OpExpr& op) override { return TFPoolNdGrad::Init(op, "tf_avg"); }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_1d", TFAvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_2d", TFAvgPoolNdGrad);
REGISTER_OP_EXPR_GRAD_FUNCTION("tf_avg_pool_3d", TFAvgPoolNdGrad);

}  // namespace one
}  // namespace oneflow
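
Every gradient functor in this commit follows the same Init/Capture/Apply protocol: Init caches the forward op's base attributes (typically via MakeAttrMapFromUserOpConf), Capture records the requires_grad flag, attribute values, and any tensors the backward pass needs, and Apply maps out_grads to in_grads. As a reading aid, here is a minimal sketch of that skeleton for a hypothetical identity-like op; the op name "my_passthrough_op" is invented for illustration and is not part of this commit. The to_contiguous.cpp file that follows is essentially this pattern verbatim.

// Minimal sketch of the Init/Capture/Apply pattern used by the files in this commit.
// "my_passthrough_op" is a hypothetical op name, used only for illustration.
struct PassThroughCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class PassThroughGrad : public OpExprGradFunction<PassThroughCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(PassThroughCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad = inputs.at(0)->requires_grad();  // remember whether backward is needed
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const PassThroughCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    // An identity-like op simply passes the output gradient through unchanged.
    if (ctx->requires_grad) { (*in_grads)[0] = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("my_passthrough_op", PassThroughGrad);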

oneflow/core/autograd/gradient_funcs/to_contiguous.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace oneflow {
namespace one {

struct ToContiguousCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
};

class ToContiguous : public OpExprGradFunction<ToContiguousCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(ToContiguousCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs[0]->requires_grad();
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ToContiguousCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(1);
    if (ctx->requires_grad) { (*in_grads)[0] = out_grads[0]; }
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("to_contiguous", ToContiguous);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/transpose.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct TransposeCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> perm;
  bool requires_grad;
};

class Transpose : public OpExprGradFunction<TransposeCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Transpose::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Transpose::Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->perm = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("perm"));
  return Maybe<void>::Ok();
}

Maybe<void> Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  std::vector<int32_t> grad_perm;
  grad_perm.resize(ctx->perm.size());
  FOR_RANGE(int32_t, i, 0, ctx->perm.size()) { grad_perm.at(ctx->perm.at(i)) = i; }
  in_grads->at(0) = JUST(functional::Transpose(out_grads.at(0), grad_perm));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("transpose", Transpose);

}  // namespace one
}  // namespace oneflow
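
The backward of transpose is a transpose by the inverse permutation: the loop in Transpose::Apply builds grad_perm so that grad_perm[perm[i]] = i. A standalone illustration of that inversion (the helper name and the example values are chosen only for illustration):

// Illustration only: the inverse-permutation computation performed in Transpose::Apply.
#include <cstdint>
#include <vector>

std::vector<int32_t> InversePerm(const std::vector<int32_t>& perm) {
  std::vector<int32_t> grad_perm(perm.size());
  // grad_perm[perm[i]] = i undoes the forward permutation.
  for (int32_t i = 0; i < static_cast<int32_t>(perm.size()); ++i) { grad_perm.at(perm.at(i)) = i; }
  return grad_perm;  // e.g. perm = {2, 0, 1} gives grad_perm = {1, 2, 0}
}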

oneflow/core/autograd/gradient_funcs/tril.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct TrilCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  int64_t diagonal = 0;
};

class Tril : public OpExprGradFunction<TrilCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(TrilCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Tril::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Tril::Capture(TrilCaptureState* ctx, const TensorTuple& inputs,
                          const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>("diagonal"));
  return Maybe<void>::Ok();
}

Maybe<void> Tril::Apply(const TrilCaptureState* ctx, const TensorTuple& out_grads,
                        TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (ctx->requires_grad) {
    in_grads->at(0) = JUST(functional::Tril(out_grads.at(0), ctx->diagonal));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("tril", Tril);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/triu.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct TriuCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  int64_t diagonal;
};

class Triu : public OpExprGradFunction<TriuCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(TriuCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Triu::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Triu::Capture(TriuCaptureState* ctx, const TensorTuple& inputs,
                          const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->diagonal = JUST(composed_attrs.GetAttr<int64_t>("diagonal"));
  return Maybe<void>::Ok();
}

Maybe<void> Triu::Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
                        TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(1);
  if (ctx->requires_grad) {
    in_grads->at(0) = JUST(functional::Triu(out_grads.at(0), ctx->diagonal));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("triu", Triu);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/two_stage_reduce.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

enum class ReduceMode : int32_t {
  kMin = 0,
  kMax = 1,
};

struct ReduceDeviceCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> axis;
  bool requires_grad = false;
  size_t mask_index = -1;
  size_t count_index = -1;
};

template<ReduceMode mode>
class ReduceDevice : public OpExprGradFunction<ReduceDeviceCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ReduceDeviceCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
    ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1));   // mask
    ctx->count_index = ctx->SaveTensorForBackward(outputs.at(2));  // count
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ReduceDeviceCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 3);  // NOLINT(maybe-need-error-msg)

    const auto& mask = ctx->SavedTensors().at(ctx->mask_index);
    const auto& count = ctx->SavedTensors().at(ctx->count_index);
    in_grads->resize(1);
    if (mode == ReduceMode::kMin) {
      in_grads->at(0) =
          JUST(functional::ReduceMinDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));
    } else {
      in_grads->at(0) =
          JUST(functional::ReduceMaxDeviceStageGrad(out_grads.at(0), mask, count, ctx->axis));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_device_stage", ReduceDevice<ReduceMode::kMin>);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_device_stage", ReduceDevice<ReduceMode::kMax>);

struct ReduceGlobalCaptureState : public AutoGradCaptureState {
  std::vector<int32_t> axis;
  bool requires_grad = false;
  bool keepdims = false;
  size_t mask_index = -1;
  size_t device_count_index = -1;
};

template<ReduceMode mode>
class ReduceGlobal : public OpExprGradFunction<ReduceGlobalCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(ReduceGlobalCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 2);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
    ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
    ctx->mask_index = ctx->SaveTensorForBackward(outputs.at(1));         // mask
    ctx->device_count_index = ctx->SaveTensorForBackward(inputs.at(1));  // device_count
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ReduceGlobalCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)

    const auto& mask = ctx->SavedTensors().at(ctx->mask_index);
    const auto& device_count = ctx->SavedTensors().at(ctx->device_count_index);
    in_grads->resize(2);
    if (mode == ReduceMode::kMin) {
      in_grads->at(0) = JUST(functional::ReduceMinGlobalStageGrad(
          out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));
    } else {
      in_grads->at(0) = JUST(functional::ReduceMaxGlobalStageGrad(
          out_grads.at(0), mask, device_count, ctx->axis, ctx->keepdims));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min_global_stage", ReduceGlobal<ReduceMode::kMin>);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max_global_stage", ReduceGlobal<ReduceMode::kMax>);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/unfold.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct UnfoldInterpState : public AutoGradCaptureState {
  bool requires_grad = true;
  std::string data_format = "channels_first";
  std::vector<int32_t> output_size;
  std::vector<int32_t> kernel_size;
  std::vector<int32_t> dilation_rate;
  std::vector<int32_t> padding;
  std::vector<int32_t> strides;
};

class Unfold : public OpExprGradFunction<UnfoldInterpState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UnfoldInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Unfold::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Unfold::Capture(UnfoldInterpState* ctx, const TensorTuple& inputs,
                            const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  std::vector<int32_t> out_shape(2);
  const std::shared_ptr<Tensor>& x = inputs.at(0);
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
  // Only support 4-d Tensor Input.
  for (int i = 0; i < 2; i++) { out_shape.at(i) = (x->shape()->At(i + 2)); }
  ctx->output_size = out_shape;
  return Maybe<void>::Ok();
}

Maybe<void> Unfold::Apply(const UnfoldInterpState* ctx, const TensorTuple& out_grads,
                          TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::Fold(out_grads.at(0), ctx->data_format, ctx->output_size,
                                          ctx->kernel_size, ctx->dilation_rate, ctx->padding,
                                          ctx->strides));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("unfold", Unfold);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/unfold_tensor.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct UnfoldTensorCaptureState : public AutoGradCaptureState {
  int32_t dimension = -1;
  int32_t size = -1;
  int32_t step = -1;
  bool requires_grad = false;
};

class UnfoldTensor : public OpExprGradFunction<UnfoldTensorCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
  std::shared_ptr<OpExpr> grad_op_;
};

Maybe<void> UnfoldTensor::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> UnfoldTensor::Capture(UnfoldTensorCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->dimension = JUST(composed_attrs.GetAttr<int32_t>("dimension"));
  ctx->size = JUST(composed_attrs.GetAttr<int32_t>("size"));
  ctx->step = JUST(composed_attrs.GetAttr<int32_t>("step"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> UnfoldTensor::Apply(const UnfoldTensorCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  const auto& in = ctx->SavedTensors().at(0);
  in_grads->at(0) = JUST(functional::UnfoldTensorGrad(out_grads.at(0), in, ctx->dimension,
                                                      ctx->size, ctx->step));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("unfold_tensor", UnfoldTensor);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/unsqueeze.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace one {

struct UnsqueezeCaptureState : public AutoGradCaptureState {
  bool requires_grad;
  Shape shape;
};

class Unsqueeze : public OpExprGradFunction<UnsqueezeCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Unsqueeze::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Unsqueeze::Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  if (LazyMode::is_enabled()) {
    ctx->SaveTensorForBackward(inputs.at(0));
  } else {
    ctx->shape = *(inputs.at(0)->shape());
  }
  return Maybe<void>::Ok();
}

Maybe<void> Unsqueeze::Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  in_grads->resize(1);
  if (LazyMode::is_enabled()) {
    const auto& like = ctx->SavedTensors().at(0);
    in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), like));
  } else {
    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), ctx->shape));
  }
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("expand_dims", Unsqueeze);

}  // namespace one
}  // namespace oneflow

oneflow/core/autograd/gradient_funcs/upsample.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace one {

struct UpsampleCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  float align_corners;
  std::string data_format;
  std::string interpolation;
};

class Upsample : public OpExprGradFunction<UpsampleCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
  std::shared_ptr<OpExpr> grad_op_;
};

Maybe<void> Upsample::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
  ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
  ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
  ctx->interpolation = JUST(composed_attrs.GetAttr<std::string>("interpolation"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
  in_grads->resize(1);
  JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleGrad(
      JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
      ctx->align_corners, ctx->data_format, ctx->interpolation));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);

struct UpsampleNearest2DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleNearest2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleNearest2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_2d", UpsampleNearest2D);

struct UpsampleBilinear2DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleBilinear2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleBilinear2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBilinear2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bilinear_2d", UpsampleBilinear2D);

struct UpsampleLinear1DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double scale_factor = 0.0;
  bool align_corners;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleLinear1DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleLinear1DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleLinear1DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->align_corners,
        ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D);

struct UpsampleNearest1DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double scale_factor = 0.0;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->scale_factor = JUST(composed_attrs.GetAttr<double>("scale_factor"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleNearest1DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest1DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->scale_factor, ctx->output_size,
        ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_1d", UpsampleNearest1D);

struct UpsampleBicubic2DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleBicubic2DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleBicubic2DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleBicubic2DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
        ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D);

struct UpsampleNearest3DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double depth_scale = 0.0;
  double height_scale = 0.0;
  double width_scale = 0.0;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleNearest3DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleNearest3DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
        ctx->width_scale, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_3d", UpsampleNearest3D);

struct UpsampleTrilinear3DCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  double depth_scale = 0.0;
  double height_scale = 0.0;
  double width_scale = 0.0;
  bool align_corners;
  std::vector<int64_t> output_size;
  std::string data_format;
};

class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->depth_scale = JUST(composed_attrs.GetAttr<double>("depth_scale"));
    ctx->height_scale = JUST(composed_attrs.GetAttr<double>("height_scale"));
    ctx->width_scale = JUST(composed_attrs.GetAttr<double>("width_scale"));
    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
    if (base_attrs_.find("output_size") != base_attrs_.end()) {
      ctx->output_size = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("output_size"));
    }
    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    MutableAttrMap attrs;
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleTrilinear3DGrad(
        JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
        ctx->width_scale, ctx->align_corners, ctx->output_size, ctx->data_format));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D);

}  // namespace one
}  // namespace oneflow
\ No newline at end of file

oneflow/core/autograd/gradient_funcs/variance.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct VarianceState : public AutoGradCaptureState {
  VarianceState() : requires_grad(false), unbiased(true), keepdim(false), axis({}){};
  bool requires_grad;
  bool unbiased;
  bool keepdim;
  std::vector<int32_t> axis;
};

class Variance : public OpExprGradFunction<VarianceState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(VarianceState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                      const AttrMap& attrs) const override;
  Maybe<void> Apply(const VarianceState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> Variance::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> Variance::Capture(VarianceState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
  CHECK_EQ_OR_RETURN(inputs.size(), 1);   // NOLINT(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->keepdim = JUST(composed_attrs.GetAttr<bool>("keepdim"));
  ctx->unbiased = JUST(composed_attrs.GetAttr<bool>("unbiased"));
  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dim"));
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
  // TODO(): replace it using kernel
  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
  size_t correction = ctx->unbiased ? 1 : 0;
  size_t elem_cnt = 1;
  CHECK_OR_RETURN(ctx->axis.size() > 0)
      << Error::RuntimeError() << "The size of the axis must greater than 0, but got "
      << ctx->axis.size();
  for (const auto& item : ctx->axis) { elem_cnt *= x->shape()->At(item); }

  std::shared_ptr<Tensor> out_grad = out_grads.at(0);
  if (ctx->keepdim == false) {
    // for broadcast mul
    const std::shared_ptr<const Shape>& out_grad_shape = out_grad->shape();
    DimVector unsqueeze_vector(out_grad_shape->dim_vec());
    for (int i = 0; i < ctx->axis.size(); i++) {
      unsqueeze_vector.insert(unsqueeze_vector.begin() + ctx->axis.at(i), 1);
    }
    Shape unsqueeze_shape(unsqueeze_vector);
    CHECK_EQ_OR_RETURN(unsqueeze_shape.elem_cnt(), out_grad_shape->elem_cnt())
        << Error::RuntimeError()
        << "tensor size mismatch, expected tensor to have the same number of elements, but got "
        << unsqueeze_shape.elem_cnt() << " and " << out_grad_shape->elem_cnt()
        << " elements respectively";
    out_grad = JUST(functional::Reshape(out_grad, unsqueeze_shape));
  }

  in_grads->resize(1);
  in_grads->at(0) = JUST(functional::Mul(
      out_grad,
      JUST(functional::ScalarMul(
          Scalar(2.0 / (elem_cnt - correction)),
          JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)),
                               /*alpha=*/1.0, /*inplace=*/false))))));
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("var", Variance);

}  // namespace one
}  // namespace oneflow
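Note (a reading of the code above, not part of the commit): with N the number of elements reduced over `dim` and c = 1 when `unbiased` else 0, the backward rule that Variance::Apply encodes is, as a sketch,

\[
\frac{\partial \operatorname{Var}(x)}{\partial x_i} = \frac{2\,(x_i - \bar{x})}{N - c},
\qquad
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \operatorname{Var}} \cdot \frac{2\,(x_i - \bar{x})}{N - c},
\]

which corresponds to the Mul(out_grad, ScalarMul(2 / (elem_cnt - correction), x - mean)) chain in the code, with out_grad reshaped for broadcasting when `keepdim` is false.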
oneflow/core/autograd/gradient_funcs/where.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {

struct WhereCaptureState : public AutoGradCaptureState {
  bool requires_grad_x;
  bool requires_grad_y;
};

struct WhereScalarCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class Where : public OpExprGradFunction<WhereCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(WhereCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;
};

Maybe<void> Where::Init(const OpExpr& op) { return Maybe<void>::Ok(); }

Maybe<void> Where::Capture(WhereCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad_x = inputs.at(1)->requires_grad();
  ctx->requires_grad_y = inputs.at(2)->requires_grad();
  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }

  ctx->SaveTensorForBackward(inputs.at(0));  // condition
  ctx->SaveTensorForBackward(inputs.at(1));  // x
  ctx->SaveTensorForBackward(inputs.at(2));  // y
  return Maybe<void>::Ok();
}

Maybe<void> Where::Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
  if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

  const std::shared_ptr<oneflow::one::Tensor>& condition = ctx->SavedTensors().at(0);
  const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(1);
  const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(2);
  std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(x));
  in_grads->resize(3);
  if (ctx->requires_grad_x) {
    auto broad_x_grad = JUST(functional::Where(condition, out_grads.at(0), zero_out));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_x_grad, x));
  }
  if (ctx->requires_grad_y) {
    auto broad_y_grad = JUST(functional::Where(condition, zero_out, out_grads.at(0)));
    in_grads->at(2) = JUST(functional::BroadcastReduceSumLike(broad_y_grad, y));
  }
  return Maybe<void>::Ok();
}

class WhereScalar : public OpExprGradFunction<WhereScalarCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(WhereScalarCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad = inputs.at(1)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ctx->SaveTensorForBackward(inputs.at(0));
    ctx->SaveTensorForBackward(inputs.at(1));
    return Maybe<void>::Ok();
  }
};

class WhereScalarX : public WhereScalar {
 public:
  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

    const std::shared_ptr<oneflow::one::Tensor>& condition = ctx->SavedTensors().at(0);
    const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(1);
    std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(y));
    in_grads->resize(2);
    auto broad_y_grad = JUST(functional::Where(condition, zero_out, out_grads.at(0)));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_y_grad, y));
    return Maybe<void>::Ok();
  }
};

class WhereScalarY : public WhereScalar {
 public:
  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)

    const std::shared_ptr<oneflow::one::Tensor>& condition = ctx->SavedTensors().at(0);
    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(1);
    std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(x));
    in_grads->resize(2);
    auto broad_x_grad = JUST(functional::Where(condition, out_grads.at(0), zero_out));
    in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(broad_x_grad, x));
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("where", Where);
REGISTER_OP_EXPR_GRAD_FUNCTION("where_scalar_x", WhereScalarX);
REGISTER_OP_EXPR_GRAD_FUNCTION("where_scalar_y", WhereScalarY);

}  // namespace one
}  // namespace oneflow
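Note (a reading of the code above, not part of the commit): with mask m = condition and z = where(m, x, y), the gradient Where::Apply implements is, as a sketch,

\[
\frac{\partial L}{\partial x} = \operatorname{reduce\_sum\_like}\!\Big(m \odot \frac{\partial L}{\partial z},\, x\Big),
\qquad
\frac{\partial L}{\partial y} = \operatorname{reduce\_sum\_like}\!\Big((1-m) \odot \frac{\partial L}{\partial z},\, y\Big),
\]

where the masking is done with functional::Where against a ZerosLike tensor and BroadcastReduceSumLike folds the broadcasted gradient back to the shape of x or y.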
oneflow/core/boxing/asymmetric_broadcast.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/framework/placement_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/decorator.h"
namespace oneflow {

namespace {

Maybe<void> RawCheckAsymmetricBroadcast(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                                        const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_OR_RETURN(out->placement()->Bigger(*in->placement())
                  || in->placement()->Bigger(*out->placement()));
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckAsymmetricBroadcast =
    DECORATE(&RawCheckAsymmetricBroadcast, ThreadLocalCachedCopiable);

Maybe<int64_t> CalBroadcastRoot(Symbol<ParallelDesc> src_parallel_desc,
                                Symbol<ParallelDesc> dst_parallel_desc) {
  int64_t machine_id = -1;
  int64_t device_id = -1;
  for (int64_t mach_id : src_parallel_desc->sorted_machine_ids()) {
    bool machine_and_device_id_inited = false;
    for (int64_t dev_id : src_parallel_desc->sorted_dev_phy_ids(mach_id)) {
      if (dst_parallel_desc->Containing(mach_id, dev_id)) {
        machine_id = mach_id;
        device_id = dev_id;
        machine_and_device_id_inited = true;
        break;
      }
    }
    if (machine_and_device_id_inited) { break; }
  }
  // Always true, if check failed, there is a bug in oneflow needed to be resolved.
  CHECK_OR_RETURN(machine_id != -1 && device_id != -1)
      << Error::RuntimeError()
      << "Calculate the intersection of placements "
         "failed during execution of asymmetric broadcast,"
      << ", placement_a: " << *JUST(PlacementToString(src_parallel_desc))
      << ", placement_b: " << *JUST(PlacementToString(dst_parallel_desc))
      << "! Please submit an issue in `https://github.com/Oneflow-Inc/oneflow/issues` "
         "and we will fix it as soon as possible";
  return machine_id;
}

static constexpr auto* CachedGetBroadcastRoot = DECORATE(&CalBroadcastRoot, ThreadLocalCached);

Maybe<one::UserOpExpr> EagerNcclBroadcast(Symbol<ParallelDesc> parallel_desc, int64_t root) {
  return one::OpBuilder("eager_nccl_broadcast", *JUST(UniqueStr("eager_nccl_broadcast")))
      .Input("in")
      .Output("out")
      .Attr<std::string>("parallel_conf", PbMessage2TxtString(parallel_desc->parallel_conf()))
      .Attr<int64_t>("root", root)
      .Build();
}

static constexpr auto* CachedEagerNcclBroadcast = DECORATE(&EagerNcclBroadcast, ThreadLocalCached);

}  // namespace

Maybe<one::Tensor> AsymmetricBroadcast(const std::shared_ptr<one::Tensor>& tensor,
                                       Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
  const auto& in_placement = in->placement();
  const auto& out_placement = out->placement();
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in_placement)
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in_placement)) << ")";
  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
  if (out->placement()->Bigger(*in->placement())) {
    const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out_placement));
    if (out_parallel_id->has_value()) {
      const auto& in_parallel_id = JUST(GetParallelId4CurrentProcessCtx(in_placement));
      if (!in_parallel_id->has_value()) {
        const std::string& device_type = in_placement->device_tag();
        local_tensor = JUST(one::functional::Empty(*tensor->shape(), tensor->dtype(),
                                                   JUST(Device::New(device_type)),
                                                   /*pin_memory=*/false));
      }
      const auto& broadcast_group = JUST(GetBroadcastGroup(in_placement, out_placement));
      Symbol<ParallelDesc> broadcast_placement_cur_rank =
          JUST(MapAt(*broadcast_group, GlobalProcessCtx::Rank()));
      int64_t root = JUST(CachedGetBroadcastRoot(in_placement, broadcast_placement_cur_rank));
      std::shared_ptr<one::UserOpExpr> op_expr =
          JUST(CachedEagerNcclBroadcast(broadcast_placement_cur_rank, root));
      local_tensor = JUST(one::OpInterpUtil::Dispatch<one::Tensor>(*op_expr, {local_tensor}));
    }
  }
  return one::functional::LocalToConsistent(local_tensor, out_placement,
                                            *JUST(GetSbpList(out->nd_sbp())), *tensor->shape(),
                                            tensor->dtype());
}

COMMAND(RegisterBoxingFunction("asymmetric-broadcast", CheckAsymmetricBroadcast,
                               &AsymmetricBroadcast));

}  // namespace oneflow
oneflow/core/boxing/boxing_dividor.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
#include <functional>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {

class PlacedNdSbp;

class BoxingDividor final {
 public:
  BoxingDividor(const BoxingDividor&) = delete;
  BoxingDividor(BoxingDividor&&) = delete;
  ~BoxingDividor() = default;

  using FunctionT =
      std::function<Maybe<Symbol<PlacedNdSbp>>(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out)>;

  BoxingDividor(const std::string& name, const FunctionT& function)
      : name_(name), function_(function) {}

  const std::string& name() const { return name_; }

  Maybe<Symbol<PlacedNdSbp>> operator()(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) const {
    return function_(in, out);
  }

 private:
  std::string name_;
  FunctionT function_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_H_
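For readers unfamiliar with the pattern, BoxingDividor is a named wrapper around a std::function that maps an (in, out) pair of placed nd-sbps to a new placed nd-sbp. A minimal standalone sketch of the same shape, using toy stand-in types instead of Symbol<PlacedNdSbp>/Maybe (all names below are hypothetical and for illustration only, not part of OneFlow):

#include <functional>
#include <iostream>
#include <string>

// Toy stand-in for Symbol<PlacedNdSbp>.
struct ToyPlacedNdSbp {
  std::string desc;
};

// Same shape as BoxingDividor: a name plus a callable over (in, out).
class ToyDividor final {
 public:
  using FunctionT = std::function<ToyPlacedNdSbp(ToyPlacedNdSbp in, ToyPlacedNdSbp out)>;
  ToyDividor(const std::string& name, const FunctionT& fn) : name_(name), fn_(fn) {}
  const std::string& name() const { return name_; }
  ToyPlacedNdSbp operator()(ToyPlacedNdSbp in, ToyPlacedNdSbp out) const { return fn_(in, out); }

 private:
  std::string name_;
  FunctionT fn_;
};

int main() {
  // Roughly analogous to "OutPlacementAndBroadcast": keep the output placement, force broadcast.
  ToyDividor dividor("OutPlacementAndBroadcast", [](ToyPlacedNdSbp in, ToyPlacedNdSbp out) {
    return ToyPlacedNdSbp{out.desc + " + broadcast"};
  });
  std::cout << dividor.name() << ": "
            << dividor(ToyPlacedNdSbp{"in"}, ToyPlacedNdSbp{"out"}).desc << "\n";
  return 0;
}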
oneflow/core/boxing/boxing_dividor_util.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/boxing_dividor_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {

namespace {

Maybe<BoxingDividor> RawReplaceInDeviceType(DeviceType device_type) {
  return std::make_shared<BoxingDividor>(
      "ReplaceInDeviceType",
      [device_type](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        const auto& new_placement = JUST(ReplaceDeviceType(in->placement(), device_type));
        return PlacedNdSbp::New(in->nd_sbp(), new_placement);
      });
}

Maybe<BoxingDividor> RawReplaceOutDeviceType(DeviceType device_type) {
  return std::make_shared<BoxingDividor>(
      "ReplaceOutDeviceType",
      [device_type](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        const auto& new_placement = JUST(ReplaceDeviceType(out->placement(), device_type));
        return PlacedNdSbp::New(out->nd_sbp(), new_placement);
      });
}

}  // namespace

decltype(ReplaceInDeviceType) ReplaceInDeviceType =
    DECORATE(&RawReplaceInDeviceType, ThreadLocalCached);
decltype(ReplaceOutDeviceType) ReplaceOutDeviceType =
    DECORATE(&RawReplaceOutDeviceType, ThreadLocalCached);

namespace {

Maybe<Symbol<PlacedNdSbp>> RawFlattenHierarchy(Symbol<PlacedNdSbp> placed_nd_sbp) {
  CHECK_GE_OR_RETURN(placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0);
  for (const auto& sbp_parallel : placed_nd_sbp->nd_sbp()->sbp_parallel()) {
    CHECK_OR_RETURN(sbp_parallel == first_sbp_parallel)
        << Error::RuntimeError()
        << "Expected all sbps to be on the same in sbp list during flatten sbps list, but find at "
           "least two sbps, "
        << SbpToString(first_sbp_parallel) << " and " << SbpToString(sbp_parallel) << "!";
  }
  std::vector<Symbol<SbpParallel>> vec{SymbolOf(first_sbp_parallel)};
  const auto& flattened_nd_sbp = JUST(GetNdSbp(vec));
  ParallelConf flattened_parallel_conf(placed_nd_sbp->placement()->parallel_conf());
  flattened_parallel_conf.clear_hierarchy();
  const auto& flattened_placement = SymbolOf(ParallelDesc(flattened_parallel_conf));
  return JUST(PlacedNdSbp::New(flattened_nd_sbp, flattened_placement));
}

static constexpr auto* FlattenHierarchy = DECORATE(&RawFlattenHierarchy, ThreadLocalCached);

Maybe<BoxingDividor> RawFlattenInHierarchy() {
  return std::make_shared<BoxingDividor>(
      "FlattenInHierarchy",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return FlattenHierarchy(in);
      });
}

Maybe<Symbol<PlacedNdSbp>> RawUnflattenHierarchy(Symbol<PlacedNdSbp> in_placed_nd_sbp,
                                                 Symbol<PlacedNdSbp> out_placed_nd_sbp) {
  CHECK_GE_OR_RETURN(in_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  CHECK_GE_OR_RETURN(out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0)
      << Error::RuntimeError() << "Invalid nd_sbp with ndim equal 0!";
  const auto& in_sbp_parallel = in_placed_nd_sbp->nd_sbp()->sbp_parallel(0);
  NdSbp unflattened_nd_sbp;
  for (int64_t i = 0; i < out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) {
    unflattened_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(in_sbp_parallel);
  }
  return JUST(PlacedNdSbp::New(SymbolOf(unflattened_nd_sbp), out_placed_nd_sbp->placement()));
}

static constexpr auto* UnflattenHierarchy = DECORATE(&RawUnflattenHierarchy, ThreadLocalCached);

Maybe<BoxingDividor> RawUnflattenInHierarchy() {
  return std::make_shared<BoxingDividor>(
      "UnflattenInHierarchy",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return UnflattenHierarchy(in, out);
      });
}

Maybe<BoxingDividor> RawUnflattenOutHierarchy() {
  return std::make_shared<BoxingDividor>(
      "UnflattenOutHierarchy",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return UnflattenHierarchy(out, in);
      });
}

}  // namespace

decltype(FlattenInHierarchy) FlattenInHierarchy =
    DECORATE(&RawFlattenInHierarchy, ThreadLocalCached);
decltype(UnflattenInHierarchy) UnflattenInHierarchy =
    DECORATE(&RawUnflattenInHierarchy, ThreadLocalCached);
decltype(UnflattenOutHierarchy) UnflattenOutHierarchy =
    DECORATE(&RawUnflattenOutHierarchy, ThreadLocalCached);

namespace {

Maybe<Symbol<NdSbp>> GetAllPartialSumNdSbp(int64_t ndim) {
  NdSbp partial_sum_nd_sbp;
  for (int64_t i = 0; i < ndim; ++i) {
    partial_sum_nd_sbp.mutable_sbp_parallel()->Add()->mutable_partial_sum_parallel();
  }
  return SymbolOf(partial_sum_nd_sbp);
}

auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocalCached);

Maybe<Symbol<PlacedNdSbp>> RawReplaceNdSbpWithPartialSum(Symbol<PlacedNdSbp> placed_nd_sbp) {
  Symbol<NdSbp> partial_sum_nd_sbp =
      JUST(CachedGetAllPartialSumNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size()));
  return JUST(PlacedNdSbp::New(partial_sum_nd_sbp, placed_nd_sbp->placement()));
}

static constexpr auto* ReplaceNdSbpWithPartialSum =
    DECORATE(&RawReplaceNdSbpWithPartialSum, ThreadLocalCached);

Maybe<BoxingDividor> RawOutPlacementAndPartialSum() {
  return std::make_shared<BoxingDividor>(
      "OutPlacementAndPartialSum",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return ReplaceNdSbpWithPartialSum(out);
      });
}

}  // namespace

decltype(OutPlacementAndPartialSum) OutPlacementAndPartialSum =
    DECORATE(&RawOutPlacementAndPartialSum, ThreadLocalCached);

namespace {

Maybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(int64_t ndim) {
  NdSbp broadcast_nd_sbp;
  for (int64_t i = 0; i < ndim; ++i) {
    broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
  }
  return SymbolOf(broadcast_nd_sbp);
}

auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocalCached);

Maybe<Symbol<PlacedNdSbp>> RawReplaceNdSbpWithBroadcast(Symbol<PlacedNdSbp> placed_nd_sbp) {
  Symbol<NdSbp> broadcast_nd_sbp =
      JUST(CachedGetAllBroadcastNdSbp(placed_nd_sbp->nd_sbp()->sbp_parallel_size()));
  return JUST(PlacedNdSbp::New(broadcast_nd_sbp, placed_nd_sbp->placement()));
}

static constexpr auto* ReplaceNdSbpWithBroadcast =
    DECORATE(&RawReplaceNdSbpWithBroadcast, ThreadLocalCached);

Maybe<BoxingDividor> RawInPlacementAndBroadcast() {
  return std::make_shared<BoxingDividor>(
      "InPlacementAndBroadcast",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return ReplaceNdSbpWithBroadcast(in);
      });
}

Maybe<BoxingDividor> RawOutPlacementAndBroadcast() {
  return std::make_shared<BoxingDividor>(
      "OutPlacementAndBroadcast",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return ReplaceNdSbpWithBroadcast(out);
      });
}

}  // namespace

decltype(InPlacementAndBroadcast) InPlacementAndBroadcast =
    DECORATE(&RawInPlacementAndBroadcast, ThreadLocalCached);
decltype(OutPlacementAndBroadcast) OutPlacementAndBroadcast =
    DECORATE(&RawOutPlacementAndBroadcast, ThreadLocalCached);

namespace {

Maybe<Symbol<NdSbp>> GetSplitNdSbp(int64_t axis) {
  NdSbp split_nd_sbp;
  split_nd_sbp.mutable_sbp_parallel()->Add()->mutable_split_parallel()->set_axis(axis);
  return SymbolOf(split_nd_sbp);
}

auto* CachedGetSplitNdSbp = DECORATE(&GetSplitNdSbp, ThreadLocalCached);

Maybe<BoxingDividor> RawInPlacementAndSplit(int64_t axis) {
  return std::make_shared<BoxingDividor>(
      "InPlacementAndSplit",
      [=](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));
        return PlacedNdSbp::New(split_nd_sbp, in->placement());
      });
}

Maybe<BoxingDividor> RawOutPlacementAndSplit(int64_t axis) {
  return std::make_shared<BoxingDividor>(
      "OutPlacementAndSplit",
      [=](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        Symbol<NdSbp> split_nd_sbp = JUST(CachedGetSplitNdSbp(axis));
        return PlacedNdSbp::New(split_nd_sbp, out->placement());
      });
}

}  // namespace

decltype(InPlacementAndSplit) InPlacementAndSplit =
    DECORATE(&RawInPlacementAndSplit, ThreadLocalCached);
decltype(OutPlacementAndSplit) OutPlacementAndSplit =
    DECORATE(&RawOutPlacementAndSplit, ThreadLocalCached);

namespace {

Maybe<Symbol<ParallelDesc>> GetFisrtDeviceOfPlacement(Symbol<ParallelDesc> placement) {
  ParallelConf parallel_conf;
  int64_t machine_id = JUST(placement->MachineId4ParallelId(0));
  int64_t device_id = JUST(placement->DeviceId4ParallelId(0));
  parallel_conf.set_device_tag(placement->device_tag());
  parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":"
                                + std::to_string(device_id));
  for (int64_t i = 0; i < placement->hierarchy()->NumAxes(); ++i) {
    parallel_conf.mutable_hierarchy()->add_dim(1);
  }
  std::shared_ptr<ParallelDesc> parallel_desc;
  JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
    parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
    return Maybe<void>::Ok();
  }));
  return SymbolOf(*parallel_desc);
}

Maybe<BoxingDividor> RawInFirstDeviceAndAllBroadcast() {
  return std::make_shared<BoxingDividor>(
      "InFirstDeviceAndAllBroadcast",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return PlacedNdSbp::New(JUST(CachedGetAllBroadcastNdSbp(in->nd_sbp()->sbp_parallel_size())),
                                JUST(GetFisrtDeviceOfPlacement(in->placement())));
      });
}

Maybe<BoxingDividor> RawOutFirstDeviceAndAllBroadcast() {
  return std::make_shared<BoxingDividor>(
      "OutFirstDeviceAndAllBroadcast",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return PlacedNdSbp::New(
            JUST(CachedGetAllBroadcastNdSbp(out->nd_sbp()->sbp_parallel_size())),
            JUST(GetFisrtDeviceOfPlacement(out->placement())));
      });
}

}  // namespace

decltype(InFirstDeviceAndAllBroadcast) InFirstDeviceAndAllBroadcast =
    DECORATE(&RawInFirstDeviceAndAllBroadcast, ThreadLocalCached);
decltype(OutFirstDeviceAndAllBroadcast) OutFirstDeviceAndAllBroadcast =
    DECORATE(&RawOutFirstDeviceAndAllBroadcast, ThreadLocalCached);

namespace {

Maybe<Symbol<PlacedNdSbp>> RawPlacementAndRepeatFirstSbp(Symbol<PlacedNdSbp> placed_nd_sbp) {
  const auto& first_sbp_parallel = placed_nd_sbp->nd_sbp()->sbp_parallel(0);
  NdSbp out_nd_sbp;
  for (int64_t i = 0; i < placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) {
    out_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(first_sbp_parallel);
  }
  return JUST(PlacedNdSbp::New(SymbolOf(out_nd_sbp), placed_nd_sbp->placement()));
}

static constexpr auto* PlacementAndRepeatFirstSbp =
    DECORATE(&RawPlacementAndRepeatFirstSbp, ThreadLocalCached);

Maybe<BoxingDividor> RawInPlacementAndRepeatFirstSbp() {
  return std::make_shared<BoxingDividor>(
      "InPlacementAndRepeatFirstSbp",
      [](Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) -> Maybe<Symbol<PlacedNdSbp>> {
        return PlacementAndRepeatFirstSbp(in);
      });
}

}  // namespace

decltype(InPlacementAndRepeatFirstSbp) InPlacementAndRepeatFirstSbp =
    DECORATE(&RawInPlacementAndRepeatFirstSbp, ThreadLocalCached);

}  // namespace oneflow
oneflow/core/boxing/boxing_dividor_util.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
#define ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/boxing/boxing_dividor.h"
namespace oneflow {

extern Maybe<BoxingDividor> (*ReplaceInDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*ReplaceOutDeviceType)(DeviceType device_type);
extern Maybe<BoxingDividor> (*FlattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenInHierarchy)();
extern Maybe<BoxingDividor> (*UnflattenOutHierarchy)();
extern Maybe<BoxingDividor> (*OutPlacementAndPartialSum)();
extern Maybe<BoxingDividor> (*InPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*OutPlacementAndBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*OutPlacementAndSplit)(int64_t axis);
extern Maybe<BoxingDividor> (*InFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*OutFirstDeviceAndAllBroadcast)();
extern Maybe<BoxingDividor> (*InPlacementAndRepeatFirstSbp)();

}  // namespace oneflow

#endif  // ONEFLOW_CORE_BOXING_BOXING_DIVIDOR_UTIL_H_
oneflow/core/boxing/boxing_interpreter_status.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/boxing/boxing_interpreter_status.h"
namespace oneflow {

namespace {

Maybe<BoxingInterpreterStatus> RawMakeBoxingInterpreterStatus(const std::string& boxing_name,
                                                              const Shape& logical_shape,
                                                              Symbol<PlacedNdSbp> in,
                                                              Symbol<PlacedNdSbp> out) {
  std::vector<std::string> sorted_boxing_names{boxing_name};
  BoxingInterpreterStatus status(SymbolOf(sorted_boxing_names), logical_shape, in, out);
  return status;
}

Maybe<BoxingInterpreterStatus> RawMakeComposedBoxingInterpreterStatus(
    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,
    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status) {
  CHECK_OR_RETURN(lhs_status->dst_placed_nd_sbp() == rhs_status->src_placed_nd_sbp())  // always true
      << Error::RuntimeError()
      << "Intermediate placed_nd_sbp must be equal when compose boxing interpreter status"
      << ". lhs_status.dst_nd_sbp: " << NdSbpToString(lhs_status->dst_placed_nd_sbp()->nd_sbp())
      << ", rhs_status.dst_nd_sbp: " << NdSbpToString(rhs_status->src_placed_nd_sbp()->nd_sbp())
      << ", lhs_status.dst_placement: "
      << *JUST(PlacementToString(lhs_status->dst_placed_nd_sbp()->placement()))
      << ", rhs_status.dst_placement: "
      << *JUST(PlacementToString(rhs_status->src_placed_nd_sbp()->placement()));
  CHECK_OR_RETURN(lhs_status->logical_shape() == rhs_status->logical_shape())  // always true
      << Error::RuntimeError()
      << "Logical_shape must be equal when compose boxing interpreter status"
      << ". lhs_status.logical_shape: " << (lhs_status->logical_shape().ToString())
      << ". rhs_status.logical_shape: " << (rhs_status->logical_shape().ToString());
  std::vector<std::string> sorted_boxing_names(*lhs_status->sorted_boxing_names());
  sorted_boxing_names.insert(sorted_boxing_names.end(),
                             rhs_status->sorted_boxing_names()->begin(),
                             rhs_status->sorted_boxing_names()->end());
  std::vector<Symbol<PlacedNdSbp>> mid_placed_nd_sbp(*lhs_status->mid_placed_nd_sbp());
  mid_placed_nd_sbp.emplace_back(lhs_status->dst_placed_nd_sbp());
  mid_placed_nd_sbp.insert(mid_placed_nd_sbp.end(), rhs_status->mid_placed_nd_sbp()->begin(),
                           rhs_status->mid_placed_nd_sbp()->end());
  BoxingInterpreterStatus status(sorted_boxing_names, lhs_status->logical_shape(),
                                 lhs_status->src_placed_nd_sbp(), SymbolOf(mid_placed_nd_sbp),
                                 rhs_status->dst_placed_nd_sbp());
  return status;
}

}  // namespace

decltype(MakeBoxingInterpreterStatus) MakeBoxingInterpreterStatus =
    DECORATE(&RawMakeBoxingInterpreterStatus, ThreadLocalCachedCopiable);
decltype(MakeComposedBoxingInterpreterStatus) MakeComposedBoxingInterpreterStatus =
    DECORATE(&RawMakeComposedBoxingInterpreterStatus, ThreadLocalCachedCopiable);

namespace {

Maybe<std::string> RawGetNdSbpRouting(Symbol<PlacedNdSbp> src_placed_nd_sbp,
                                      Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
                                      Symbol<PlacedNdSbp> dst_placed_nd_sbp) {
  std::ostringstream ss;
  ss << NdSbpToString(src_placed_nd_sbp->nd_sbp());
  for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) {
    ss << " -> " << NdSbpToString(placed_nd_sbp->nd_sbp());
  }
  ss << " -> " << NdSbpToString(dst_placed_nd_sbp->nd_sbp());
  return ss.str();
}

Maybe<std::string> RawGetPlacementRouting(Symbol<PlacedNdSbp> src_placed_nd_sbp,
                                          Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
                                          Symbol<PlacedNdSbp> dst_placed_nd_sbp) {
  std::ostringstream ss;
  ss << *JUST(PlacementToString(src_placed_nd_sbp->placement()));
  for (const auto& placed_nd_sbp : *mid_placed_nd_sbp) {
    ss << " -> " << *JUST(PlacementToString(placed_nd_sbp->placement()));
  }
  ss << " -> " << *JUST(PlacementToString(dst_placed_nd_sbp->placement()));
  return ss.str();
}

Maybe<std::string> RawGetBoxingDesc(Symbol<std::vector<std::string>> sorted_boxing_names) {
  CHECK_OR_RETURN(!sorted_boxing_names->empty())  // always true
      << Error::RuntimeError() << "boxing_names of eager boxing status can't be empty!";
  std::ostringstream ss;
  ss << sorted_boxing_names->at(0);
  for (size_t i = 1; i < sorted_boxing_names->size(); ++i) {
    ss << " -> " << sorted_boxing_names->at(i);
  }
  return ss.str();
}

static constexpr auto* GetNdSbpRouting = DECORATE(&RawGetNdSbpRouting, ThreadLocalCached);
static constexpr auto* GetPlacementRouting = DECORATE(&RawGetPlacementRouting, ThreadLocalCached);
static constexpr auto* GetBoxingDesc = DECORATE(&RawGetBoxingDesc, ThreadLocalCached);

}  // namespace

const std::string& BoxingInterpreterStatus::boxing_routing() const {
  return *CHECK_JUST(GetBoxingDesc(sorted_boxing_names_));
}

const std::string& BoxingInterpreterStatus::nd_sbp_routing() const {
  return *CHECK_JUST(GetNdSbpRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));
}

const std::string& BoxingInterpreterStatus::placement_routing() const {
  return *CHECK_JUST(GetPlacementRouting(src_placed_nd_sbp_, mid_placed_nd_sbp_, dst_placed_nd_sbp_));
}

}  // namespace oneflow
oneflow/core/boxing/boxing_interpreter_status.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
#define ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/placed_nd_sbp.h"
#include "oneflow/core/common/shape.h"
namespace oneflow {

class BoxingInterpreterStatus;

extern Maybe<BoxingInterpreterStatus> (*MakeBoxingInterpreterStatus)(const std::string& boxing_name,
                                                                     const Shape& logical_shape,
                                                                     Symbol<PlacedNdSbp> in,
                                                                     Symbol<PlacedNdSbp> out);
extern Maybe<BoxingInterpreterStatus> (*MakeComposedBoxingInterpreterStatus)(
    const std::shared_ptr<BoxingInterpreterStatus>& lhs_status,
    const std::shared_ptr<BoxingInterpreterStatus>& rhs_status);

class BoxingInterpreterStatus final {
 public:
  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,
                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,
                          Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp,
                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)
      : sorted_boxing_names_(sorted_boxing_names),
        logical_shape_(logical_shape),
        src_placed_nd_sbp_(src_placed_nd_sbp),
        mid_placed_nd_sbp_(mid_placed_nd_sbp),
        dst_placed_nd_sbp_(dst_placed_nd_sbp) {}
  BoxingInterpreterStatus(Symbol<std::vector<std::string>> sorted_boxing_names,
                          const Shape& logical_shape, Symbol<PlacedNdSbp> src_placed_nd_sbp,
                          Symbol<PlacedNdSbp> dst_placed_nd_sbp)
      : BoxingInterpreterStatus(sorted_boxing_names, logical_shape, src_placed_nd_sbp,
                                SymbolOf(std::vector<Symbol<PlacedNdSbp>>()), dst_placed_nd_sbp) {}
  ~BoxingInterpreterStatus() = default;

  bool operator==(const BoxingInterpreterStatus& other) const {
    return this->sorted_boxing_names_ == other.sorted_boxing_names_
           && this->src_placed_nd_sbp_ == other.src_placed_nd_sbp_
           && this->mid_placed_nd_sbp_ == other.mid_placed_nd_sbp_
           && this->dst_placed_nd_sbp_ == other.dst_placed_nd_sbp_;
  }

  // Getters
  Symbol<std::vector<std::string>> sorted_boxing_names() const { return sorted_boxing_names_; }
  const Shape& logical_shape() const { return logical_shape_; }
  Symbol<PlacedNdSbp> src_placed_nd_sbp() const { return src_placed_nd_sbp_; }
  Symbol<PlacedNdSbp> dst_placed_nd_sbp() const { return dst_placed_nd_sbp_; }
  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp() const { return mid_placed_nd_sbp_; }
  const std::string& boxing_routing() const;
  const std::string& nd_sbp_routing() const;
  const std::string& placement_routing() const;

 private:
  Symbol<std::vector<std::string>> sorted_boxing_names_;
  const Shape logical_shape_;
  Symbol<PlacedNdSbp> src_placed_nd_sbp_;
  Symbol<std::vector<Symbol<PlacedNdSbp>>> mid_placed_nd_sbp_;
  Symbol<PlacedNdSbp> dst_placed_nd_sbp_;
};

}  // namespace oneflow

namespace std {

template<>
struct hash<oneflow::BoxingInterpreterStatus> {
  size_t operator()(const oneflow::BoxingInterpreterStatus& status) const {
    size_t ret = 0;
    for (const auto& boxing_name : *status.sorted_boxing_names()) {
      ret ^= std::hash<string>()(boxing_name);
    }
    const auto& placed_nd_sbp_hash = std::hash<oneflow::PlacedNdSbp>();
    ret ^= placed_nd_sbp_hash(*status.src_placed_nd_sbp());
    for (const auto& mid_placed_nd_sbp : *status.mid_placed_nd_sbp()) {
      ret ^= placed_nd_sbp_hash(*mid_placed_nd_sbp);
    }
    ret ^= placed_nd_sbp_hash(*status.dst_placed_nd_sbp());
    return hash<size_t>()(ret);
  }
};

}  // namespace std

#endif  // ONEFLOW_CORE_BOXING_BOXING_INTERPRETER_STATUS_H_
oneflow/core/boxing/ccl_boxing_function.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {

namespace {

Maybe<void> RawCheckCclP2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
                  || in->placement()->device_type() == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclP2B = DECORATE(&RawCheckCclP2B, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclP2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllPartialSum(*in->nd_sbp()));
  CHECK_OR_RETURN(NdSbpIsAllSplit(*out->nd_sbp(), 0));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
                  || in->placement()->device_type() == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclP2S = DECORATE(&RawCheckCclP2S, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclS2B(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(NdSbpIsAllSplit(*in->nd_sbp(), 0));
  CHECK_OR_RETURN(NdSbpIsAllBroadcast(*out->nd_sbp()));
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), 0);
  CHECK_OR_RETURN(logical_shape.At(0) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
                  || in->placement()->device_type() == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclS2B = DECORATE(&RawCheckCclS2B, ThreadLocalCachedCopiable);

Maybe<void> RawCheckCclS2S(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                           const Shape& logical_shape) {
  // NOLINTBEGIN(maybe-need-error-msg)
  CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_EQ_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1);
  CHECK_OR_RETURN(in->nd_sbp()->sbp_parallel(0).has_split_parallel());
  CHECK_OR_RETURN(out->nd_sbp()->sbp_parallel(0).has_split_parallel());
  CHECK_NE_OR_RETURN(in->nd_sbp()->sbp_parallel(0).split_parallel().axis(),
                     out->nd_sbp()->sbp_parallel(0).split_parallel().axis());
  int64_t in_split_axis = in->nd_sbp()->sbp_parallel(0).split_parallel().axis();
  int64_t out_split_axis = out->nd_sbp()->sbp_parallel(0).split_parallel().axis();
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), in_split_axis);
  CHECK_GT_OR_RETURN(logical_shape.NumAxes(), out_split_axis);
  CHECK_OR_RETURN(logical_shape.At(in_split_axis) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(logical_shape.At(out_split_axis) % in->placement()->parallel_num() == 0);
  CHECK_OR_RETURN(in->placement() == out->placement());
  CHECK_OR_RETURN(in->placement()->device_type() == DeviceType::kCPU
                  || in->placement()->device_type() == DeviceType::kCUDA);
  // NOLINTEND(maybe-need-error-msg)
  return Maybe<void>::Ok();
}

static constexpr auto* CheckCclS2S = DECORATE(&RawCheckCclS2S, ThreadLocalCachedCopiable);

}  // namespace

Maybe<one::Tensor> CclP2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentAllReduce(tensor));
}

Maybe<one::Tensor> CclP2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentReduceScatter(tensor, "sum"));
}

Maybe<one::Tensor> CclS2B(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentAllGather(tensor));
}

Maybe<one::Tensor> CclS2S(const std::shared_ptr<one::Tensor>& tensor, Symbol<PlacedNdSbp> in,
                          Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  return JUST(one::functional::ConsistentS2S(tensor, *JUST(GetSbpList(out->nd_sbp()))));
}

COMMAND(RegisterBoxingFunction("ccl-p-to-b", CheckCclP2B, &CclP2B));
COMMAND(RegisterBoxingFunction("ccl-p-to-s", CheckCclP2S, &CclP2S));
COMMAND(RegisterBoxingFunction("ccl-s-to-b", CheckCclS2B, &CclS2B));
COMMAND(RegisterBoxingFunction("ccl-s-to-s", CheckCclS2S, &CclS2S));

}  // namespace oneflow
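For orientation (a summary read off the code above, not part of the commit): each single-axis sbp transition handled in this file maps to one collective, namely P→B via ConsistentAllReduce, P→S(0) via ConsistentReduceScatter("sum"), S(0)→B via ConsistentAllGather, and S(i)→S(j) with i ≠ j via ConsistentS2S. A trivial standalone snippet recording that mapping, purely illustrative:

#include <iostream>
#include <map>
#include <string>

int main() {
  // Transition -> collective, as implemented by CclP2B / CclP2S / CclS2B / CclS2S above.
  const std::map<std::string, std::string> ccl_boxing = {
      {"P -> B", "all-reduce (ConsistentAllReduce)"},
      {"P -> S(0)", "reduce-scatter, sum (ConsistentReduceScatter)"},
      {"S(0) -> B", "all-gather (ConsistentAllGather)"},
      {"S(i) -> S(j), i != j", "all-to-all (ConsistentS2S)"},
  };
  for (const auto& kv : ccl_boxing) { std::cout << kv.first << " : " << kv.second << "\n"; }
  return 0;
}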
oneflow/core/boxing/cuda_copy_boxing_interpreter.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/boxing/eager_boxing_interpreter.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {

namespace {

Maybe<bool> IgnoringDeviceTypeEqual(Symbol<ParallelDesc> lhs, Symbol<ParallelDesc> rhs) {
  return lhs == JUST(ReplaceDeviceType(rhs, lhs->device_type()));
}

}  // namespace

// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> CheckCopyH2D(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(equal);
  CHECK_EQ_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_NE_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}

Maybe<void> CheckCopyD2H(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
                         const Shape& logical_shape) {
  bool equal = JUST(IgnoringDeviceTypeEqual(in->placement(), out->placement()));
  CHECK_OR_RETURN(equal);
  CHECK_NE_OR_RETURN(in->placement()->device_type(), DeviceType::kCPU);
  CHECK_EQ_OR_RETURN(out->placement()->device_type(), DeviceType::kCPU);
  CHECK_OR_RETURN(in->nd_sbp() == out->nd_sbp());
  return Maybe<void>::Ok();
}
// NOLINTEND(maybe-need-error-msg)

Maybe<one::Tensor> CopyBoxingFunction(const std::shared_ptr<one::Tensor>& tensor,
                                      Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
  const auto& tensor_nd_sbp = JUST(tensor->nd_sbp());
  CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp())
      << Error::RuntimeError() << "The sbp of input tensor (" << NdSbpToString(tensor_nd_sbp)
      << ") must match the input sbp (" << NdSbpToString(in->nd_sbp()) << ")";
  const auto& tensor_placement = JUST(tensor->parallel_desc());
  CHECK_OR_RETURN(tensor_placement == in->placement())
      << Error::RuntimeError() << "The placement of input tensor ("
      << *JUST(PlacementToString(tensor_placement)) << ") must match the input placement ("
      << *JUST(PlacementToString(in->placement())) << ")";
  std::shared_ptr<one::Tensor> local_tensor = JUST(tensor->cur_rank_phy_tensor());
  const auto& out_parallel_id = JUST(GetParallelId4CurrentProcessCtx(out->placement()));
  if (!out_parallel_id->has_value()) {
    const std::string& device_type = tensor_placement->device_tag();
    local_tensor = JUST(one::functional::Empty(
        *JUST(GetPhysicalShape(*tensor->shape(), *tensor_nd_sbp, *tensor_placement, 0)),
        tensor->dtype(), JUST(Device::New(device_type)), /*pin_memory=*/false));
  }
  const auto& sbp_list = JUST(GetSbpList(out->nd_sbp()));
  return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list,
                                                 *tensor->shape(), tensor->dtype()));
}

COMMAND(RegisterBoxingFunction("copy-h2d", &CheckCopyH2D, &CopyBoxingFunction));
COMMAND(RegisterBoxingFunction("copy-d2h", &CheckCopyD2H, &CopyBoxingFunction));

}  // namespace oneflow