OpenDAS / Oneflow · Commits

Commit 21d47d0e, authored Oct 24, 2022 by yuguo

Oneflow 0.8 for DCU
Changes (556 files) — showing 20 changed files with 3572 additions and 0 deletions (+3572, -0).
oneflow/api/python/rpc/rank_group.cpp              +44    -0
oneflow/api/python/session/session.cpp             +56    -0
oneflow/api/python/session/session.h               +117   -0
oneflow/api/python/symbol/job_conf_symbol.cpp      +53    -0
oneflow/api/python/symbol/op_conf_symbol.cpp       +39    -0
oneflow/api/python/symbol/placement_symbol.cpp     +264   -0
oneflow/api/python/symbol/sbp_symbol.cpp           +111   -0
oneflow/api/python/symbol/scope_symbol.cpp         +59    -0
oneflow/api/python/utils/dataloader.cpp            +228   -0
oneflow/api/python/utils/tensor_utils.cpp          +319   -0
oneflow/api/python/utils/tensor_utils.h            +169   -0
oneflow/core/auto_parallel/boxing_collector.cpp    +1010  -0
oneflow/core/auto_parallel/boxing_collector.h      +161   -0
oneflow/core/autograd/autograd_captured_tensor.h   +64    -0
oneflow/core/autograd/autograd_engine.cpp          +430   -0
oneflow/core/autograd/autograd_engine.h            +171   -0
oneflow/core/autograd/autograd_function.cpp        +43    -0
oneflow/core/autograd/autograd_function.h          +43    -0
oneflow/core/autograd/autograd_meta.cpp            +77    -0
oneflow/core/autograd/autograd_meta.h              +114   -0
Too many changes to show: to preserve performance, only 556 of 556+ files are displayed.
oneflow/api/python/rpc/rank_group.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"
namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> CheckCurrentRankGroupConsistency() {
  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
  const auto& ctx = JUST(CheckTransportToken(rank_group));
  JUST(ctx->WaitDone());
  return Maybe<void>::Ok();
}

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("check_current_rank_group_consistency", &CheckCurrentRankGroupConsistency);
}

}  // namespace oneflow
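For orientation, a minimal Python sketch of calling the binding registered above. The internal module path `oneflow._oneflow_internal` is an assumption about where ONEFLOW_API_PYBIND11_MODULE("", m) is exposed; the call itself just runs CheckCurrentRankGroupConsistency() for the current rank group.

# Hedged usage sketch; the module path is an assumption.
import oneflow._oneflow_internal as internal

# Performs the rank-group transport-token check and waits until every rank
# in the current rank group has completed it.
internal.check_current_rank_group_consistency()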
oneflow/api/python/session/session.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/session.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/api/python/session/session.h"
namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("IsSessionInited", &IsSessionInited);
  m.def("InitLazyGlobalSession", &InitLazyGlobalSession);
  m.def("InitEagerGlobalSession", &InitEagerGlobalSession);
  m.def("DestroyLazyGlobalSession", &DestroyLazyGlobalSession);
  m.def("StartLazyGlobalSession", &StartLazyGlobalSession);
  m.def("StopLazyGlobalSession", &StopLazyGlobalSession);

  using namespace oneflow;
  py::class_<MultiClientSessionContext, std::shared_ptr<MultiClientSessionContext>>(m, "SessionContext")
      .def(py::init<const std::shared_ptr<EnvGlobalObjectsScope>&>())
      .def("try_init",
           [](MultiClientSessionContext& session, const std::string& config_proto_str) {
             return session.TryInit(config_proto_str).GetOrThrow();
           })
      .def("update_resource",
           [](MultiClientSessionContext& session, const std::string& reso_proto_str) {
             return session.UpdateResource(reso_proto_str).GetOrThrow();
           });

  m.def("NewSessionId", &NewSessionId);
  py::class_<LogicalConfigProtoContext>(m, "LogicalConfigProtoContext")
      .def(py::init<const std::string&>());
}

}  // namespace oneflow
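A hedged sketch of the lazy-session lifecycle these bindings expose. The internal module path and the text-format ConfigProto contents are assumptions; the call order mirrors the inline helpers defined in session.h below.

# Hedged sketch; module path and ConfigProto fields are assumptions.
import oneflow._oneflow_internal as internal

config = "resource { cpu_device_num: 4 }"   # text-format ConfigProto (assumed fields)
if not internal.IsSessionInited():
    internal.InitLazyGlobalSession(config)  # parses the config and pushes it to the ctrl client
internal.StartLazyGlobalSession()           # compiles the registered job set (errors if it is empty)
internal.StopLazyGlobalSession()
internal.DestroyLazyGlobalSession()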
oneflow/api/python/session/session.h  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_SESSION_SESSION_H_
#define ONEFLOW_API_PYTHON_SESSION_SESSION_H_
#include <string>
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/session_global_objects_scope.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
namespace oneflow {

inline Maybe<bool> IsSessionInited() {
  return Singleton<SessionGlobalObjectsScope>::Get() != nullptr;
}

inline void FixCpuDeviceNum(ConfigProto* config_proto) {
  if (config_proto->resource().cpu_device_num() > 0) { return; }
  config_proto->mutable_resource()->set_cpu_device_num(std::thread::hardware_concurrency());
}

inline Maybe<void> InitEagerGlobalSession(const std::string& config_proto_str) {
  CHECK_NOTNULL_OR_RETURN(Singleton<EnvDesc>::Get()) << "env not found";

  ConfigProto config_proto;
  CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto))
      << "failed to parse config_proto: " << config_proto_str;
  FixCpuDeviceNum(&config_proto);
  Singleton<CtrlClient>::Get()->PushKV("config_proto", config_proto);

  CHECK_ISNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get());
  Singleton<SessionGlobalObjectsScope>::SetAllocated(new SessionGlobalObjectsScope());
  JUST(Singleton<SessionGlobalObjectsScope>::Get()->EagerInit(config_proto));
  VLOG(3) << "NewGlobal " << typeid(SessionGlobalObjectsScope).name();
  return Maybe<void>::Ok();
}

inline Maybe<void> InitLazyGlobalSession(const std::string& config_proto_str) {
  CHECK_NOTNULL_OR_RETURN(Singleton<EnvDesc>::Get()) << "env not found";
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());

  ClusterInstruction::MasterSendSessionStart();

  ConfigProto config_proto;
  CHECK_OR_RETURN(TxtString2PbMessage(config_proto_str, &config_proto))
      << "failed to parse config_proto: " << config_proto_str;
  FixCpuDeviceNum(&config_proto);
  Singleton<CtrlClient>::Get()->PushKV("config_proto", config_proto);

  CHECK_ISNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get());
  Singleton<SessionGlobalObjectsScope>::SetAllocated(new SessionGlobalObjectsScope());
  JUST(Singleton<SessionGlobalObjectsScope>::Get()->Init(config_proto));
  VLOG(3) << "NewGlobal " << typeid(SessionGlobalObjectsScope).name();
  return Maybe<void>::Ok();
}

inline Maybe<void> DestroyLazyGlobalSession() {
  if (Singleton<SessionGlobalObjectsScope>::Get() == nullptr) { return Maybe<void>::Ok(); }
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  Singleton<SessionGlobalObjectsScope>::Delete();
  return Maybe<void>::Ok();
}

inline Maybe<void> StartLazyGlobalSession() {
  CHECK_NOTNULL_OR_RETURN(Singleton<SessionGlobalObjectsScope>::Get()) << "session not found";
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  const JobSet& job_set = Singleton<LazyJobBuildAndInferCtxMgr>::Get()->job_set();
  if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
    TeePersistentLogStream::Create("job_set.prototxt")->Write(job_set);
  }
  if (job_set.job().empty()) { return Error::JobSetEmptyError() << "no function defined"; }
  CHECK_ISNULL_OR_RETURN(Singleton<Oneflow>::Get());
  Singleton<CtrlClient>::Get()->PushKV("session_job_set", job_set);
  Singleton<const InterJobReuseMemStrategy>::New(job_set.inter_job_reuse_mem_strategy());
  Singleton<Oneflow>::New();
  JUST(Singleton<Oneflow>::Get()->Init(job_set));
  return Maybe<void>::Ok();
}

inline Maybe<void> StopLazyGlobalSession() {
  if (Singleton<Oneflow>::Get() == nullptr) { return Maybe<void>::Ok(); }
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  CHECK_NOTNULL_OR_RETURN(Singleton<Oneflow>::Get());
  Singleton<Oneflow>::Delete();
  Singleton<const InterJobReuseMemStrategy>::Delete();
  return Maybe<void>::Ok();
}

}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_SESSION_SESSION_H_
oneflow/api/python/symbol/job_conf_symbol.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_conf.pb.h"
namespace py = pybind11;

namespace oneflow {

Maybe<JobDesc> CreateJobConfSymbol(int64_t symbol_id, const std::string& serialized_symbol_conf) {
  JobConfigProto symbol_pb;
  if (!TxtString2PbMessage(serialized_symbol_conf, &symbol_pb)) {
    THROW(RuntimeError) << "job conf parse failed.\n" << serialized_symbol_conf;
  }
  return JobDesc::New(symbol_id, symbol_pb);
}

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<JobDesc, std::shared_ptr<JobDesc>>(m, "JobConfSymbol")
      .def(py::init([](int64_t symbol_id, const std::string& serialized_symbol_conf) {
        return CreateJobConfSymbol(symbol_id, serialized_symbol_conf).GetPtrOrThrow();
      }))
      .def_property_readonly("symbol_id",
                             [](const JobDesc& x) {
                               if (!x.symbol_id().has_value()) {
                                 THROW(RuntimeError) << "symbol_id not initialized";
                               }
                               return CHECK_JUST(x.symbol_id());
                             })
      .def_property_readonly("data", [](const JobDesc& job_conf_sym) -> std::string {
        return PbMessage2TxtString(job_conf_sym.job_conf());
      });
}

}  // namespace oneflow
oneflow/api/python/symbol/op_conf_symbol.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/operator/op_conf_symbol.h"
#include "oneflow/core/common/maybe.h"
namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<OperatorConfSymbol, std::shared_ptr<OperatorConfSymbol>>(m, "OpConfSymbol")
      .def_property_readonly("symbol_id",
                             [](const OperatorConfSymbol& x) {
                               if (!x.symbol_id().has_value()) {
                                 THROW(RuntimeError) << "symbol_id not initialized";
                               }
                               return CHECK_JUST(x.symbol_id());
                             })
      .def_property_readonly("data", &OperatorConfSymbol::data);
}

}  // namespace oneflow
oneflow/api/python/symbol/placement_symbol.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/parallel_conf_util.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
namespace py = pybind11;

namespace oneflow {

namespace {

int64_t GetDeviceCount(const std::string& device_name) {
  return Singleton<ep::DeviceManagerRegistry>::Get()->GetDeviceCount(device_name);
}

struct PlacementSymbolExportUtil {
  static Maybe<void> CheckDeviceTag(const std::string& type) {
    if (!TRY(DeviceType4DeviceTag(type)).IsOk()) {
      return Error::RuntimeError() << "Expected one of " << PrintAvailableDevices()
                                   << " device type at start of device string: " << type;
    }
    return Maybe<void>::Ok();
  }

  static Maybe<ParallelDesc> CreateParallelDesc(
      const std::string& type, const std::vector<std::string>& formated_machine_device_ids,
      const std::shared_ptr<Shape>& hierarchy_shape) {
    JUST(CheckDeviceTag(type));
    auto parallel_conf = JUST(MakeParallelConf(type, formated_machine_device_ids, hierarchy_shape));
    std::shared_ptr<ParallelDesc> parallel_desc;
    JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
      parallel_desc = JUST(builder->GetParallelDescSymbol(*parallel_conf));
      return Maybe<void>::Ok();
    }));
    return parallel_desc;
  }

  static Maybe<ParallelDesc> CreateParallelDesc(const std::string& proto_str) {
    ParallelConf parallel_conf;
    CHECK_OR_RETURN(TxtString2PbMessage(proto_str, &parallel_conf))
        << " Get ParallelConf Pb from string failed.";
    std::shared_ptr<ParallelDesc> parallel_desc;
    JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
      parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
      return Maybe<void>::Ok();
    }));
    return parallel_desc;
  }

  static Maybe<std::vector<std::string>> ParseAndFormatRanks(const py::dict& device_ids) {
    std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;
    for (const auto& pair : device_ids) {
      CHECK_OR_RETURN(py::isinstance<py::int_>(pair.first))
          << "The key (node id) of placement device_ids must be int64.";
      int64_t machine_id = pair.first.cast<int64_t>();
      if (py::isinstance<py::int_>(pair.second)) {
        machine_device_id_vec.emplace_back(machine_id, pair.second.cast<int64_t>());
      } else {
        CHECK_OR_RETURN(py::isinstance<py::iterable>(pair.second))
            << "Value of device_ids dict must be int, list or range";
        for (const auto& device_id : pair.second) {
          CHECK_OR_RETURN(py::isinstance<py::int_>(device_id))
              << "Value of device_ids dict must be int, list or range of int.";
          machine_device_id_vec.emplace_back(machine_id, device_id.cast<int64_t>());
        }
      }
    }
    auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();
    for (const auto& pair : machine_device_id_vec) {
      const std::string& device_name =
          std::to_string(pair.first) + ":" + std::to_string(pair.second);
      formated_machine_device_ids->emplace_back(device_name);
    }
    return formated_machine_device_ids;
  }

  static Maybe<Shape> GetRanksShape(PyArrayObject* ranks) {
    auto* shape = PyArray_SHAPE(ranks);
    return std::make_shared<Shape>(DimVector(shape, shape + PyArray_NDIM(ranks)));
  }

  // Parse and format ranks to string "machine_id:local_rank"
  static Maybe<std::vector<std::string>> ParseAndFormatRanks(PyArrayObject* ranks) {
    size_t size = PyArray_SIZE(ranks);
    CHECK_EQ_OR_RETURN(PyArray_TYPE(ranks), NPY_INT64)
        << Error::RuntimeError() << "placement ranks should be an array of long int";
    int64_t* rank_data = static_cast<int64_t*>(PyArray_DATA(ranks));
    std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;
    for (int i = 0; i < size; ++i) {
      int64_t rank = rank_data[i];
      int64_t machine_id = GlobalProcessCtx::NodeId(rank);
      int64_t device_id = GlobalProcessCtx::LocalRank(rank);
      machine_device_id_vec.emplace_back(machine_id, device_id);
    }
    auto formated_machine_device_ids = std::make_shared<std::vector<std::string>>();
    for (const auto& pair : machine_device_id_vec) {
      auto device_name = std::to_string(pair.first) + ":" + std::to_string(pair.second);
      formated_machine_device_ids->emplace_back(device_name);
    }
    return formated_machine_device_ids;
  }

  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(
      const std::string& type, const py::dict& device_ids,
      const std::shared_ptr<Shape>& hierarchy) {
    const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(device_ids));
    return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, hierarchy)));
  }

  // create Symbol<ParallelDesc> object through given device_type and ranks parameters
  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& type,
                                                              const py::object& ranks) {
    auto* obj = reinterpret_cast<PyArrayObject*>(PyArray_FromAny(
        ranks.ptr(), nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr));
    if (!obj) { return Error::RuntimeError() << "placement ranks should be an array of long int"; }
    const auto& shape = JUST(GetRanksShape(obj));
    const auto& formated_machine_device_ids = JUST(ParseAndFormatRanks(obj));
    return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape)));
  }

  static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& proto_str) {
    return SymbolOf(*JUST(CreateParallelDesc(proto_str)));
  }

  static Maybe<Symbol<ParallelDesc>> AllDevicePlacement(const std::string& type) {
    static thread_local HashMap<std::string, Symbol<ParallelDesc>> device_tag2placement;
    CHECK_NOTNULL((Singleton<ResourceDesc, ForEnv>::Get()));
    JUST(CheckDeviceTag(type));
    auto it = device_tag2placement.find(type);
    if (it == device_tag2placement.end()) {
      int64_t node_size = GlobalProcessCtx::NodeSize();
      int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode();
      if (type != "cpu") {
        const int64_t device_count = GetDeviceCount(type);
        CHECK_NE_OR_RETURN(device_count, 0)
            << Error::RuntimeError() << "Can\'t construct placement with \"" << type
            << "\" type because there is no device!";
        device_num = std::min(device_num, device_count);
      }
      std::vector<std::string> machine_device_ids;
      for (int64_t node_id = 0; node_id < node_size; ++node_id) {
        std::string device_name = std::to_string(node_id) + ":0-" + std::to_string(device_num - 1);
        machine_device_ids.emplace_back(device_name);
      }
      Symbol<ParallelDesc> placement =
          SymbolOf(*JUST(CreateParallelDesc(type, machine_device_ids, std::shared_ptr<Shape>())));
      it = device_tag2placement.emplace(type, placement).first;
    }
    return it->second;
  }

  static Maybe<py::array> GetPlacementRanks(const Symbol<ParallelDesc>& placement) {
    py::list ranks;
    for (int64_t machine_id : placement->sorted_machine_ids()) {
      int64_t node_id = GlobalProcessCtx::NodeId(machine_id);
      for (int64_t device_id : placement->sorted_dev_phy_ids(machine_id)) {
        ranks.append(py::cast(node_id * GlobalProcessCtx::NumOfProcessPerNode() + device_id));
      }
    }
    auto array_ranks = py::cast<py::array>(ranks);
    array_ranks.resize(placement->hierarchy()->dim_vec());
    return array_ranks;
  }
};

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<Symbol<ParallelDesc>, std::shared_ptr<Symbol<ParallelDesc>>>(m, "placement",
                                                                          py::dynamic_attr())
      .def(py::init([](const std::string& device_type, const py::dict& device_ids,
                       const std::shared_ptr<Shape>& hierarchy) {
             PyErr_WarnEx(PyExc_UserWarning,
                          "The way to construct placement is deprecated, and it will be removed in next "
                          "versions. Please use oneflow.placement(type=str, ranks=int array) instead",
                          1);
             return PlacementSymbolExportUtil::CreateParallelDescSymbol(device_type, device_ids,
                                                                        hierarchy)
                 .GetOrThrow();
           }),
           py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy"))
      .def(py::init([](const std::string& device_type, const py::dict& device_ids,
                       const py::tuple& hierarchy) {
             PyErr_WarnEx(PyExc_UserWarning,
                          "The way to construct placement is deprecated, and it will be removed in next "
                          "versions. Please use oneflow.placement(type=str, ranks=int array) instead",
                          1);
             DimVector shape_dims{};
             for (const auto& dim : hierarchy) { shape_dims.emplace_back(dim.cast<int64_t>()); }
             return PlacementSymbolExportUtil::CreateParallelDescSymbol(
                        device_type, device_ids, std::make_shared<Shape>(shape_dims))
                 .GetOrThrow();
           }),
           py::arg("device_type"), py::arg("device_ids"), py::arg("hierarchy") = py::tuple())
      .def(py::init([](const std::string& type, const py::object& ranks) {
             return PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow();
           }),
           py::arg("type"), py::arg("ranks"))
      .def(py::init([](const std::string& proto_str) {
             return PlacementSymbolExportUtil::CreateParallelDescSymbol(proto_str).GetOrThrow();
           }),
           py::arg("proto_str"))
      .def_property_readonly("device_type",
                             [](Symbol<ParallelDesc> p) {
                               PyErr_WarnEx(PyExc_UserWarning,
                                            "The property .device_type of placement is deprecated, "
                                            "please use .type instead",
                                            1);
                               return p->device_tag();
                             })
      .def_property_readonly("type", [](Symbol<ParallelDesc> p) { return p->device_tag(); })
      .def_property_readonly("hierarchy",
                             [](Symbol<ParallelDesc> p) {
                               PyErr_WarnEx(PyExc_UserWarning,
                                            "The property .hierarchy of placement is deprecated, "
                                            "please use .ranks.shape instead",
                                            1);
                               return p->hierarchy();
                             })
      .def_property_readonly("ranks", &PlacementSymbolExportUtil::GetPlacementRanks)
      .def("__str__", PlacementToString)
      .def("__repr__", PlacementToString)
      .def(py::self == py::self)
      .def(py::hash(py::self));
  m.def("AllDevicePlacement", &PlacementSymbolExportUtil::AllDevicePlacement);
}

}  // namespace oneflow
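The `placement` class registered above is exposed as `oneflow.placement` in the public API. A short sketch of the constructor forms bound here; the example assumes a launch with at least four ranks and an available "cuda" device type.

# Sketch of the placement constructors bound above; assumes >= 4 ranks and a "cuda" device.
import oneflow as flow

p1 = flow.placement(type="cuda", ranks=[0, 1, 2, 3])      # 1-D rank array
p2 = flow.placement(type="cuda", ranks=[[0, 1], [2, 3]])  # 2-D rank array, p2.ranks.shape == (2, 2)
print(p1.type)   # "cuda"
print(p2.ranks)  # numpy int64 array laid out according to the hierarchy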
oneflow/api/python/symbol/sbp_symbol.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/common/sbp.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/nd_sbp.h"
namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<std::vector<Symbol<SbpParallel>>> MakeSplitSbpParallelList(int max_split_axis) {
  std::shared_ptr<std::vector<Symbol<SbpParallel>>> ret =
      std::make_shared<std::vector<Symbol<SbpParallel>>>(max_split_axis);
  for (int i = 0; i < max_split_axis; ++i) { ret->at(i) = JUST(MakeSplitSbpParallel(i)); }
  return ret;
}

Maybe<Symbol<SbpParallel>> GetSplitSbpParallel(int axis) {
  CHECK_GE_OR_RETURN(axis, 0) << Error::RuntimeError()
                              << "Split axis must not be negative, but got " << axis << "!";
  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis)
      << Error::RuntimeError()
      << "Expected split axis to be less than the supported maximum axis (" << kMaxSplitAxis
      << "), but got " << axis << "!";
  static std::vector<Symbol<SbpParallel>> split_sbp_sym_list =
      *JUST(MakeSplitSbpParallelList(kMaxSplitAxis));
  return split_sbp_sym_list.at(axis);
}

Maybe<Symbol<SbpParallel>> GetBroadcastSbpParallel() {
  static Symbol<SbpParallel> broadcast_sbp = JUST(MakeBroadcastSbpParallel());
  return broadcast_sbp;
}

Maybe<Symbol<SbpParallel>> GetPartialSumSbpParallel() {
  static Symbol<SbpParallel> partial_sum_sbp = JUST(MakePartialSumSbpParallel());
  return partial_sum_sbp;
}

Maybe<std::pair<std::string, int>> SbpGetState(const Symbol<SbpParallel>& sbp) {
  if (sbp->has_broadcast_parallel()) {
    return std::make_shared<std::pair<std::string, int>>("B", -1);
  } else if (sbp->has_partial_sum_parallel()) {
    return std::make_shared<std::pair<std::string, int>>("P", -1);
  } else if (sbp->has_split_parallel()) {
    return std::make_shared<std::pair<std::string, int>>("S", sbp->split_parallel().axis());
  } else {
    return Error::RuntimeError() << "Invalid sbp signature: " << sbp->DebugString();
  }
}

Maybe<Symbol<SbpParallel>> GetSbpFromState(const std::pair<std::string, int>& state) {
  if (state.first == "B") {
    return GetBroadcastSbpParallel();
  } else if (state.first == "P") {
    return GetPartialSumSbpParallel();
  } else if (state.first == "S") {
    return GetSplitSbpParallel(state.second);
  } else {
    return Error::RuntimeError() << "Invalid sbp signature state: (" << state.first << ", "
                                 << state.second << ");";
  }
}

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
  m.attr("max_split_axis") = kMaxSplitAxis;
  py::class_<Symbol<SbpParallel>, std::shared_ptr<Symbol<SbpParallel>>>(m, "sbp",
                                                                        py::dynamic_attr())
      .def("__str__", &api::SbpToString)
      .def("__repr__", &api::SbpToString)
      .def(py::self == py::self)
      .def(py::hash(py::self))
      .def("_ToAttrStr",
           [](const Symbol<SbpParallel>& sbp_sym) { return SbpParallelToString(*sbp_sym); })
      .def(py::pickle(
          [](const Symbol<SbpParallel>& sbp) {  // __getstate__
            return SbpGetState(sbp).GetOrThrow();
          },
          [](const std::pair<std::string, int>& state) {  // __setstate__
            return GetSbpFromState(state).GetOrThrow();
          }));
  m.def("split", GetSplitSbpParallel, py::arg("axis"));
  m.def("broadcast", &GetBroadcastSbpParallel);
  m.def("partial_sum", &GetPartialSumSbpParallel);
}

}  // namespace oneflow
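A small sketch of the "sbp" submodule registered above; the internal module path is an assumption. `split`, `broadcast`, and `partial_sum` return `Symbol<SbpParallel>` values, and `split` validates its axis against `max_split_axis`.

# Hedged sketch; the internal module path is an assumption.
import oneflow._oneflow_internal as internal

s0 = internal.sbp.split(axis=0)   # raises RuntimeError if axis < 0 or axis >= max_split_axis
b  = internal.sbp.broadcast()
ps = internal.sbp.partial_sum()
print(internal.sbp.max_split_axis, str(s0), str(b), str(ps))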
oneflow/api/python/symbol/scope_symbol.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/scope.h"
namespace py = pybind11;

namespace oneflow {

Maybe<Scope> CreateScopeSymbol(int64_t symbol_id, const std::string& symbol_conf_str) {
  ScopeProto symbol_pb;
  if (!TxtString2PbMessage(symbol_conf_str, &symbol_pb)) {
    THROW(RuntimeError) << "symbol conf parse failed.\n" << symbol_conf_str;
  }
  return Scope::New(symbol_id, symbol_pb);
}

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<Scope, std::shared_ptr<Scope>>(m, "ScopeSymbol")
      .def(py::init([](int64_t symbol_id, const std::string& symbol_conf_str) {
        return CreateScopeSymbol(symbol_id, symbol_conf_str).GetPtrOrThrow();
      }))
      .def_property_readonly("symbol_id",
                             [](const Scope& x) {
                               if (!x.symbol_id().has_value()) {
                                 THROW(RuntimeError) << "symbol_id not initialized";
                               }
                               return CHECK_JUST(x.symbol_id());
                             })
      .def_property_readonly("_proto_str",
                             [](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); })
      .def("auto_increment_id", &Scope::auto_increment_id)
      .def_property_readonly("session_id", &Scope::session_id)
      .def_property_readonly("job_desc_symbol", &Scope::job_desc_symbol)
      .def_property_readonly("device_parallel_desc_symbol",
                             [](const Scope& x) {
                               return x.device_parallel_desc_symbol().shared_from_symbol();
                             })
      .def_property_readonly("parent_scope_symbol", &Scope::parent_scope_symbol)
      .def("MakeChildScopeProto", &Scope::MakeChildScopeProto);
}

}  // namespace oneflow
oneflow/api/python/utils/dataloader.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef _WIN32
#include <atomic>
#include <map>
#include <set>
#include <csignal>
#include <sstream>
#include <sys/wait.h>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include <stdexcept>
namespace oneflow {

namespace py = pybind11;

// reference: pytorch/torch/csrc/DataLoader.cpp
// https://github.com/pytorch/pytorch/blob/d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa/torch/csrc/DataLoader.cpp

// Critical signal handlers should be registered on worker processes before
// doing work.
// The handler will raise default handler so that the kill information will be
// retrieved from main process.
// Python handle is _set_worker_signal_handlers().
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                            \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {                  \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char));   \
    (void)_w;                                                                       \
    struct sigaction sa {};                                                         \
    sa.sa_handler = SIG_DFL;                                                        \
    sa.sa_flags = 0;                                                                \
    if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, nullptr) != 0) {    \
      _exit(EXIT_FAILURE);                                                          \
    } else {                                                                        \
      raise(SIGNAL);                                                                \
    }                                                                               \
  }

// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
static inline void setSignalHandler(int signal, void (*handler)(int, siginfo_t*, void*),
                                    struct sigaction* old_sa_ptr) {
  struct sigaction sa {};
  sa.sa_sigaction = handler;
  sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) {
    std::ostringstream oss;
    oss << "An error occurred while setting handler for " << strsignal(signal) << ".";
    throw std::runtime_error(oss.str());
  }
}

SIGNAL_HANDLER(SIGBUS, handler_SIGBUS,
               "ERROR: Unexpected bus error encountered in worker. "
               "This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV,
               "ERROR: Unexpected segmentation fault encountered in worker.\n");
SIGNAL_HANDLER(SIGFPE, handler_SIGFPE,
               "ERROR: Unexpected floating-point exception encountered in worker.\n");

// When an error happened in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children
// processes first before deleting the loader object. Then the cleaning up
// methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main
// loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we
// exit with nonzero code, the loader SIGCHLD handler may report RuntimeError
// again, and then it defeats the whole purpose.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); }
  struct sigaction sa {};
  sa.sa_handler = SIG_DFL;
  sa.sa_flags = 0;
  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) {
    _exit(EXIT_FAILURE);
  } else {
    raise(SIGTERM);
  }
}

static void set_worker_signal_handlers() {
  setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
  setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
  setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
  setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static std::map<int64_t, std::set<pid_t>> worker_pids = {};

static void error_if_any_worker_fails() {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int error;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::set<pid_t>* pid_set;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  pid_t worker_pid;
  siginfo_t infop;

  // Only check the pids we care about
  for (auto& w : worker_pids) {
    pid_set = &(w.second);
    for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
      worker_pid = *pid_it;
      // Use waitid rather than waitpid so that we can set NOWAIT, and that Python
      // and other handlers can get whatever info they want about the child.
      infop.si_pid = 0;
      error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
      // ignore errors and case with no waitable child
      if (error < 0 || infop.si_pid == 0) continue;
      if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) {  // exit with error
        std::ostringstream oss;
        oss << "DataLoader worker (pid " << worker_pid << ") exited "
            << "unexpectedly with exit code " << infop.si_status << ". "
            << "Details are lost due to multiprocessing. Rerunning with "
            << "num_workers=0 may give better error trace.";
        // This is necessary. Otherwise, the runtime error will kill the other
        // workers, and trigger this again.
        pid_set->clear();
        throw std::runtime_error(oss.str());
      } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
        std::ostringstream oss;
        oss << "DataLoader worker (pid " << worker_pid << ") is killed "
            << "by signal: " << strsignal(infop.si_status) << ". ";
        if (infop.si_status == SIGBUS) {
          oss << "It is possible that dataloader's workers are out of shared memory. "
              << "Please try to raise your shared memory limit.";
        }
        // This is necessary. Otherwise, the runtime error will kill the other
        // workers, and trigger this again.
        pid_set->clear();
        throw std::runtime_error(oss.str());
      }
    }
  }
}

inline int64_t utils_unpackLong(PyObject* obj) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int overflow;
  long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
  if (value == -1 && PyErr_Occurred()) { throw py::value_error(); }
  if (overflow != 0) { throw std::runtime_error("Overflow when unpacking long"); }
  return (int64_t)value;
}

// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
// of pids we are interested in.
static void set_worker_pids(py::args py_args) {
  PyObject* args = py_args.ptr();
  if (PyTuple_GET_SIZE(args) != 2) {
    throw py::type_error("_set_worker_pids expects exactly 2 arguments.");
  }
  int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));
  if (worker_pids.find(key) != worker_pids.end()) {
    throw py::value_error(
        "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
  }
  PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
  if (!PyTuple_Check(child_pids)) {
    py::print("_set_worker_pids expects a tuple for child_pids, but got: ",
              Py_TYPE(child_pids)->tp_name);
    throw py::type_error("_set_worker_pids expects a tuple for child_pids");
  }

  std::set<pid_t> pids_set = {};
  auto size = PyTuple_GET_SIZE(child_pids);
  for (int idx = 0; idx < size; idx++) {
    PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
    pids_set.insert(static_cast<pid_t>(utils_unpackLong(obj)));
  }

  worker_pids[key] = pids_set;
}

static void remove_worker_pids(py::args py_args) {
  PyObject* args = py_args.ptr();
  int64_t key = utils_unpackLong(PyTuple_GET_ITEM(args, 0));
  auto it = worker_pids.find(key);
  if (it == worker_pids.end()) {
    py::print("Cannot find worker information for _BaseDataLoaderIter with id :", key);
    throw py::value_error("Cannot find worker information for _BaseDataLoaderIter");
  }
  worker_pids.erase(it);
}

#undef SIGNAL_HANDLER

#else
// dummy implementations for windows

static PyObject* set_worker_signal_handlers(PyObject* module, PyObject* _ignored) {
  Py_RETURN_NONE;
}

static PyObject* set_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }

static PyObject* remove_worker_pids(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; }

static PyObject* error_if_any_worker_fails(PyObject* module, PyObject* _ignored) {
  Py_RETURN_NONE;
}

#endif

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("_set_worker_signal_handlers", &set_worker_signal_handlers);
  m.def("_set_worker_pids", &set_worker_pids);
  m.def("_remove_worker_pids", &remove_worker_pids);
  m.def("_error_if_any_worker_fails", &error_if_any_worker_fails);
}

}  // namespace oneflow
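These hooks mirror PyTorch's DataLoader process management. A hedged sketch of how a Python-side loader might drive them; the module path and the surrounding loader logic are assumptions, only the four function names come from the bindings above.

# Hedged sketch; module path and surrounding loader logic are assumptions.
import oneflow._oneflow_internal as internal

def worker_entry():
    # Run inside each forked worker before it produces batches, so a SIGBUS/
    # SIGSEGV/SIGFPE in the worker is reported with a readable message.
    internal._set_worker_signal_handlers()

def register_workers(loader_id, pids):
    # Main process: associate the child pids with this loader instance.
    internal._set_worker_pids(loader_id, tuple(pids))

def poll_workers():
    # Main process: raises RuntimeError if any registered worker died abnormally.
    internal._error_if_any_worker_fails()

def shutdown(loader_id):
    internal._remove_worker_pids(loader_id)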
oneflow/api/python/utils/tensor_utils.cpp  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/switch_func.h"
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/framework/consistency_check.h"
#include "oneflow/core/functional/impl/common.h"
namespace py = pybind11;

namespace oneflow {
namespace one {

Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
  JUST(functional::CheckInplaceValid(t));
  std::shared_ptr<MirroredTensor> local_tensor;
  if (t->is_local()) {
    local_tensor = JUST(t->AsMirroredTensor());
  } else {
    local_tensor = JUST(t->cur_rank_phy_tensor());
  }
  CHECK_OR_RETURN(local_tensor->is_eager()) << "eager tensors supported only";
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    JUST(builder->AccessBlobByCallback(
        local_tensor,
        [](uint64_t of_blob_ptr) {
          auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
          of_blob->AsyncAutoMemset(0);
        },
        "mut"));
    return Maybe<void>::Ok();
  }));
  return Maybe<void>::Ok();
}

template<typename T>
Maybe<void> CopyMirroredTensorFromUntypedArray(const std::shared_ptr<Tensor>& tensor,
                                               PyObject* array) {
  return CopyBetweenMirroredTensorAndNumpy<T>(tensor, array, BlobNumpyCopyUtil<T>::From, "mut",
                                              /*block_host_until_done=*/false);
}

Maybe<std::string> GetCopyMirroredTensorToNumpyFuncName(DataType dtype) {
  using namespace oneflow;
  static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
  {type_proto, std::make_shared<std::string>("_copy_to_numpy_" #type_cpp)},
      OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
  };
  return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
}

Maybe<std::string> GetCopyMirroredTensorFromNumpyFuncName(DataType dtype) {
  using namespace oneflow;
  static const HashMap<int64_t, std::shared_ptr<std::string>> data_type2func_name{
#define DATA_TYPE_FUNC_NAME_PAIR(type_cpp, type_proto) \
  {type_proto, std::make_shared<std::string>("_copy_from_numpy_" #type_cpp)},
      OF_PP_FOR_EACH_TUPLE(DATA_TYPE_FUNC_NAME_PAIR, POD_DATA_TYPE_SEQ)
#undef DATA_TYPE_FUNC_NAME_PAIR
  };
  return JUST(MapAt(data_type2func_name, static_cast<int64_t>(dtype)));
}

Maybe<std::tuple<std::vector<Shape>, std::vector<Symbol<DType>>>>
MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
  const auto& tensor = JUST(t->AsMirroredTensor());
  if (tensor->dtype() != DType::TensorBuffer()) {
    return Error::RuntimeError() << "tensor buffer supported only";
  }
  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
  std::vector<Shape> shapes;
  std::vector<Symbol<DType>> dtypes;

  auto btb = std::make_shared<BlockingThenBusy>(1);
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->SyncAccessBlobByCallback(tensor, btb, [](uint64_t) {}, "const");
  }));
  JUST(btb->WaitUntilCntEqualZero(VirtualMachine::GetPredicatorNoMoreInstructionsFinished()));

  const auto& eager_blob_object = JUST(tensor->eager_blob_object());
  const Shape& blob_shape = eager_blob_object->shape();
  const auto* tensor_buffer_ptr = eager_blob_object->dptr<TensorBuffer>();
  for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {
    const TensorBuffer* tensor_buffer = tensor_buffer_ptr + i;
    shapes.emplace_back(tensor_buffer->shape());
    dtypes.emplace_back(DType::Get(tensor_buffer->data_type()).GetOrThrow());
  }
  return std::make_tuple(shapes, dtypes);
}

Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,
                               const AutogradMeta::Hook& hook) {
  CHECK_OR_RETURN(self->requires_grad())
      << "cannot register a hook on a tensor that doesn't require gradient";
  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
  self->mut_autograd_meta()->add_hook(hook);
  return Maybe<void>::Ok();
}

Maybe<void> RegisterTensorPostGradAccumulationHook(const std::shared_ptr<Tensor>& self,
                                                   const AutogradMeta::Hook& hook) {
  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
  self->mut_autograd_meta()->add_post_grad_accumulation_hook(hook);
  return Maybe<void>::Ok();
}

Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
  const auto& nd_sbp = JUST(tensor.nd_sbp());
  const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());
  for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {
    (*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));
  }
  return tuple;
}

#define MAKE_SWITCH_ENTRY(func_name, dtype) func_name<dtype>
DEFINE_STATIC_SWITCH_FUNC(Maybe<void>, CopyMirroredTensorFromUntypedArray, MAKE_SWITCH_ENTRY,
                          MAKE_DATA_TYPE_CTRV_SEQ(POD_AND_HALF_DATA_TYPE_SEQ));

Maybe<Tensor> MakeLocalTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
                                      const Optional<Symbol<Device>>& device,
                                      const bool requires_grad, const bool pin_memory) {
  PyObject* array = NULL;
  PyArray_Descr* np_dtype =
      dtype.has_value()
          ? PyArray_DescrFromType(JUST(numpy::OFDataTypeToNumpyType(JUST(dtype)->data_type())))
          : nullptr;
  // PyArray_FromAny steals a reference to np_dtype object, so no need to decref it.
  // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
  // array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
  // NPY_ARRAY_FORCECAST is needed, otherwise there will be a segfault.
  array = PyArray_FromAny(data, np_dtype, 0, 0,
                          NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST, nullptr);
  if (!array) {
    return Error::RuntimeError() << "Can not convert input data to a new numpy array.";
  }
  // flow.tensor([1., 2.]).dtype should be flow.float32 rather than flow.float64
  if (!PyArray_Check(data)) {
    int np_array_type = PyArray_TYPE(reinterpret_cast<PyArrayObject*>(array));
    // Cast to float if data is double sequence, rather than numpy array.
    if (np_array_type == NPY_DOUBLE && np_dtype == nullptr) {
      PyObject* fp32_array = PyArray_Cast(reinterpret_cast<PyArrayObject*>(array), NPY_FLOAT);
      Py_DECREF(array);
      array = fp32_array;
    }
  }
  auto* np_arr = reinterpret_cast<PyArrayObject*>(array);
  const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
  const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
  DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));

  Symbol<Device> device_;
  if (device) {
    device_ = JUST(device);
  } else {
    device_ = JUST(Device::New("cpu"));
  }
  std::shared_ptr<Tensor> tensor = JUST(functional::Empty(shape, JUST(DType::Get(data_type)),
                                                          device_, /*pin_memory=*/pin_memory));
  JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), tensor, array));
  Py_DECREF(array);

  JUST(tensor->set_requires_grad(requires_grad));
  return tensor;
}

namespace {

Maybe<Symbol<NdSbp>> GetAllBroadcastNdSbp(size_t ndim) {
  NdSbp broadcast_nd_sbp;
  for (size_t i = 0; i < ndim; ++i) {
    broadcast_nd_sbp.mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
  }
  return SymbolOf(broadcast_nd_sbp);
}

auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal);

}  // namespace

Maybe<Tensor> MakeConsistentTensorFromData(PyObject* data, const Optional<Symbol<DType>>& dtype,
                                           Symbol<ParallelDesc> placement,
                                           const std::vector<Symbol<SbpParallel>>& sbp_tuple,
                                           const bool requires_grad) {
  PyObject* array = NULL;
  if (PyArray_Check(data)) {
    // Only NPY_CORDER is supported, and returns a new C-style contiguous array.
    array = PyArray_NewCopy((PyArrayObject*)data, NPY_CORDER);
  } else {
    // NPY_ARRAY_DEFAULT is NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED, so the
    // array with NPY_ARRAY_DEFAULT flag is C-style contiguous.
    array = PyArray_FromAny(data, nullptr, 0, 0, NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSURECOPY, nullptr);
    if (!array) { return Error::RuntimeError() << "Can not convert input data to a numpy array."; }
  }
  auto* np_arr = reinterpret_cast<PyArrayObject*>(array);
  const npy_intp* dims_ptr = PyArray_SHAPE(np_arr);
  const Shape shape(DimVector(dims_ptr, dims_ptr + PyArray_NDIM(np_arr)));
  DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(np_arr));

  if (placement->parallel_num() > 1) {
    const void* buf_ptr = PyArray_DATA(np_arr);
    size_t array_size = PyArray_SIZE(np_arr);
    CHECK_EQ_OR_RETURN(array_size, shape.elem_cnt());
    size_t byte_size = array_size * GetSizeOfDataType(data_type);
    JUST(DataConsistencyCheck(buf_ptr, byte_size, placement));
  }

  Symbol<Device> device = JUST(Device::New(placement->device_tag()));
  std::shared_ptr<Tensor> local_tensor =
      JUST(functional::Empty(shape, JUST(DType::Get(data_type)), device, /*pin_memory=*/false));
  JUST(SwitchCopyMirroredTensorFromUntypedArray(SwitchCase(data_type), local_tensor, array));
  Py_DECREF(array);

  // Cast to float if data is double sequence, rather than numpy array.
  Symbol<DType> dtype_;
  if (dtype) {
    dtype_ = JUST(dtype);
  } else if (!dtype && data_type == DataType::kDouble && !PyArray_Check(data)) {
    dtype_ = DType::Float();
  }
  if (dtype_) { local_tensor = JUST(functional::Cast(local_tensor, dtype_, /*pin_memory=*/false)); }

  size_t sbp_dims = sbp_tuple.size();
  Symbol<NdSbp> broadcast_nd_sbp = JUST(CachedGetAllBroadcastNdSbp(sbp_dims));
  std::shared_ptr<Tensor> broadcast_tensor = JUST(functional::LocalToConsistent(
      local_tensor, placement, *JUST(GetSbpList(broadcast_nd_sbp)), shape, local_tensor->dtype()));
  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
  auto consistent_tensor = JUST(functional::ToConsistent(broadcast_tensor, placement, sbp_tuple,
                                                         grad_sbp_tuple, /* check_meta */ false));
  JUST(consistent_tensor->set_requires_grad(requires_grad));
  return consistent_tensor;
}

Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                        const bool pin_memory) {
  if (other->is_local()) {
    const Symbol<Device>& device = JUST(other->device());
    return functional::Copy(other, device->type(), device->device_id(), pin_memory);
  } else {
    const Symbol<NdSbp>& nd_sbp = JUST(other->nd_sbp());
    const std::vector<Symbol<SbpParallel>>& sbp_tuple = *JUST(GetSbpList(nd_sbp));
    std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
    // TODO:(zhaoluyang) consistent case support pin_memory
    return functional::ToConsistent(other, JUST(other->parallel_desc()), sbp_tuple, grad_sbp_tuple,
                                    /* check_meta */ false);
  }
}

Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                        const Optional<Symbol<DType>>& dtype,
                                        const Optional<Symbol<Device>>& device,
                                        const bool requires_grad, const bool pin_memory) {
  std::shared_ptr<Tensor> tensor;
  Symbol<Device> device_;
  if (device) { device_ = JUST(device); }
  if (other->is_local()) {
    if (!device) { device_ = JUST(other->device()); }
    tensor = JUST(functional::Copy(other, device_->type(), device_->device_id(),
                                   pin_memory && !dtype.has_value()));
  } else {
    tensor = JUST(functional::ConsistentToLocal(other));
    if (!device) { device_ = JUST(Device::New("cpu")); }
    tensor = JUST(functional::Copy(tensor, device_->type(), device_->device_id(),
                                   pin_memory && !dtype.has_value()));
  }
  if (dtype) {
    const Symbol<DType>& dtype_ = JUST(dtype);
    if (tensor->dtype() != dtype_) { tensor = JUST(functional::Cast(tensor, dtype_, pin_memory)); }
  }
  JUST(tensor->set_requires_grad(requires_grad));
  return tensor;
}

Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                        const Optional<Symbol<DType>>& dtype,
                                        const Symbol<ParallelDesc>& placement,
                                        const std::vector<Symbol<SbpParallel>>& sbp_tuple,
                                        const bool requires_grad) {
  std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
  bool check_meta = other->is_consistent() ? false : true;
  std::shared_ptr<Tensor> tensor = JUST(
      functional::ToConsistent(other, placement, sbp_tuple, grad_sbp_tuple, check_meta));
  if (dtype) {
    const Symbol<DType>& dtype_ = JUST(dtype);
    if (tensor->dtype() != dtype_) {
      tensor = JUST(functional::Cast(tensor, dtype_, /*pin_memory=*/false));
    }
  }
  JUST(tensor->set_requires_grad(requires_grad));
  return tensor;
}

}  // namespace one
}  // namespace oneflow
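The double-to-float defaulting documented in the comments of MakeLocalTensorFromData and MakeConsistentTensorFromData is visible from Python. A short sketch of what those branches imply, as described by the comments rather than independently verified here:

# Sketch of the dtype defaulting implemented above: a plain Python float
# sequence falls back to float32, while an explicit float64 numpy array
# keeps its dtype.
import numpy as np
import oneflow as flow

t1 = flow.tensor([1.0, 2.0])   # double sequence -> flow.float32
t2 = flow.tensor(np.ones(2))   # float64 ndarray  -> flow.float64
print(t1.dtype, t2.dtype)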
oneflow/api/python/utils/tensor_utils.h  (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
#define ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include "oneflow/api/python/framework/tensor.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/common/blocking_then_busy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/common/foreign_lock_helper.h"
namespace
py
=
pybind11
;
namespace
pybind11
{
// reference: https://github.com/pybind/pybind11/issues/1776
template
<
>
struct
format_descriptor
<
oneflow
::
float16
>
{
static
pybind11
::
dtype
dtype
()
{
handle
ptr
=
detail
::
npy_api
::
get
().
PyArray_DescrFromType_
(
NPY_FLOAT16
);
return
reinterpret_borrow
<
pybind11
::
dtype
>
(
ptr
);
}
static
std
::
string
format
()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
return
"e"
;
}
static
constexpr
auto
name
()
{
return
detail
::
_
(
"float16"
);
}
};
}
// namespace pybind11
namespace
oneflow
{
namespace
one
{
Maybe
<
void
>
EagerMirroredTensorZeros
(
const
std
::
shared_ptr
<
Tensor
>&
t
);
template
<
typename
T
>
inline
static
Maybe
<
PyObject
*>
EagerMirroredTensorToNumpy
(
PyObject
*
py_tensor
)
{
const
auto
&
t
=
PyTensor_Unpack
(
py_tensor
);
std
::
shared_ptr
<
MirroredTensor
>
tensor
=
JUST
(
t
->
AsMirroredTensor
());
CHECK_OR_RETURN
(
JUST
(
tensor
->
device
())
==
JUST
(
Device
::
New
(
"cpu"
)));
CHECK_OR_RETURN
(
tensor
->
is_eager
())
<<
"eager tensors supported only."
;
// set base object attr
py
::
handle
handle
=
py
::
handle
(
py_tensor
);
const
size_t
ndim
=
tensor
->
ndim
();
const
auto
shape
=
numpy
::
OFShapeToNumpyShape
(
tensor
->
shape
()
->
dim_vec
());
// NumPy strides use bytes. OneFlow strides use element counts.
const
auto
stride
=
numpy
::
OFStrideToNumpyStride
(
*
JUST
(
tensor
->
stride
()),
tensor
->
dtype
()
->
data_type
());
T
*
data_ptr
=
nullptr
;
const
auto
&
Callback
=
[
&
](
uint64_t
ofblob_ptr
)
{
data_ptr
=
reinterpret_cast
<
OfBlob
*>
(
ofblob_ptr
)
->
mut_blob
()
->
mut_dptr
<
T
>
();
};
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
(
1
);
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
Callback
,
"mut"
);
}));
JUST
(
btb
->
WaitUntilCntEqualZero
(
VirtualMachine
::
GetPredicatorNoMoreInstructionsFinished
()));
return
py
::
array
(
py
::
buffer_info
(
data_ptr
,
sizeof
(
T
),
py
::
format_descriptor
<
T
>::
format
(),
ndim
,
shape
,
stride
),
handle
)
.
release
()
.
ptr
();
}
template
<
typename
T
>
inline
Maybe
<
void
>
CopyBetweenMirroredTensorAndNumpy
(
const
std
::
shared_ptr
<
Tensor
>&
t
,
PyObject
*
array
,
Maybe
<
void
>
(
*
Copy
)(
uint64_t
,
const
NumPyArrayPtr
&
),
const
std
::
string
&
modifier
,
bool
block_host_until_done
)
{
auto
tensor
=
JUST
(
t
->
AsMirroredTensor
());
CHECK_OR_RETURN
(
tensor
->
is_eager
())
<<
"eager tensors supported only."
;
if
(
block_host_until_done
)
{
NumPyArrayPtr
array_ptr
(
array
);
const
auto
&
Callback
=
[
array_ptr
,
Copy
](
uint64_t
ofblob_ptr
)
{
CHECK_JUST
(
Copy
(
ofblob_ptr
,
array_ptr
));
};
auto
btb
=
std
::
make_shared
<
BlockingThenBusy
>
(
1
);
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
SyncAccessBlobByCallback
(
tensor
,
btb
,
Callback
,
modifier
);
}));
JUST
(
btb
->
WaitUntilCntEqualZero
(
VirtualMachine
::
GetPredicatorNoMoreInstructionsFinished
()));
}
else
{
Py_INCREF
(
array
);
NumPyArrayPtr
array_ptr
(
array
,
[
array
]()
{
CHECK_JUST
(
Singleton
<
ForeignLockHelper
>::
Get
()
->
WithScopedAcquire
([
&
]()
->
Maybe
<
void
>
{
Py_DECREF
(
array
);
return
Maybe
<
void
>::
Ok
();
}));
});
JUST
(
PhysicalRun
([
&
](
InstructionsBuilder
*
builder
)
->
Maybe
<
void
>
{
return
builder
->
AccessBlobByCallback
(
tensor
,
[
array_ptr
,
Copy
](
uint64_t
ofblob_ptr
)
{
CHECK_JUST
(
Copy
(
ofblob_ptr
,
array_ptr
));
},
modifier
);
}));
}
return
Maybe
<
void
>::
Ok
();
}
Maybe
<
std
::
string
>
GetCopyMirroredTensorToNumpyFuncName
(
DataType
dtype
);
Maybe
<
std
::
string
>
GetCopyMirroredTensorFromNumpyFuncName
(
DataType
dtype
);
Maybe
<
std
::
tuple
<
std
::
vector
<
Shape
>
,
std
::
vector
<
Symbol
<
DType
>>>>
MaybeGetTensorBufferShapesAndDTypes
(
const
std
::
shared_ptr
<
Tensor
>&
t
);
Maybe
<
void
>
RegisterTensorHook
(
const
std
::
shared_ptr
<
Tensor
>&
self
,
const
AutogradMeta
::
Hook
&
hook
);
Maybe
<
void
>
RegisterTensorPostGradAccumulationHook
(
const
std
::
shared_ptr
<
Tensor
>&
self
,
const
AutogradMeta
::
Hook
&
hook
);
Maybe
<
py
::
tuple
>
TensorGetPyTupleOfSbp
(
const
Tensor
&
tensor
);
Maybe
<
Tensor
>
MakeLocalTensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Optional
<
Symbol
<
Device
>>&
device
,
const
bool
requires_grad
,
const
bool
pin_memory
);
Maybe
<
Tensor
>
MakeConsistentTensorFromData
(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
Symbol
<
ParallelDesc
>
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
);
Maybe
<
Tensor
>
MakeTensorFromOtherTensor
(
const
std
::
shared_ptr
<
Tensor
>&
other
,
const
bool
pin_memory
);
Maybe
<
Tensor
>
MakeTensorFromOtherTensor
(
const
std
::
shared_ptr
<
Tensor
>&
other
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Optional
<
Symbol
<
Device
>>&
device
,
const
bool
requires_grad
,
const
bool
pin_memory
);
Maybe
<
Tensor
>
MakeTensorFromOtherTensor
(
const
std
::
shared_ptr
<
Tensor
>&
other
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Symbol
<
ParallelDesc
>&
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
);
}
// namespace one
}
// namespace oneflow
#endif // ONEFLOW_API_PYTHON_UTILS_TENSOR_UTILS_H_
oneflow/core/auto_parallel/boxing_collector.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <string>
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/register/blob_desc.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/lazy_mode.h"
namespace
oneflow
{
namespace
{
void
DfsSetNdSbp
(
const
std
::
vector
<::
oneflow
::
SbpParallel
>&
id2sbp_parallel
,
int32_t
depth
,
int32_t
max_depth
,
NdSbp
&
nd_sbp
,
std
::
vector
<
NdSbp
>&
nd_sbp_lists
,
std
::
unordered_map
<::
oneflow
::
NdSbp
,
int32_t
>&
nd_sbp_universe
)
{
if
(
depth
==
max_depth
)
{
nd_sbp_universe
[
nd_sbp
]
=
nd_sbp_lists
.
size
();
nd_sbp_lists
.
push_back
(
nd_sbp
);
}
else
{
for
(
const
auto
&
sbp_parallel
:
id2sbp_parallel
)
{
*
nd_sbp
.
mutable_sbp_parallel
(
depth
)
=
sbp_parallel
;
DfsSetNdSbp
(
id2sbp_parallel
,
depth
+
1
,
max_depth
,
nd_sbp
,
nd_sbp_lists
,
nd_sbp_universe
);
}
}
}
// Let a nd sbp be consistent with the given hierarchy number
Maybe
<
NdSbp
>
SetNdSbpDim
(
NdSbp
nd_sbp
,
int32_t
hierarchy_num
)
{
// Do not need to change
if
(
nd_sbp
.
sbp_parallel_size
()
==
hierarchy_num
)
{
return
nd_sbp
;
}
// (S0, S0) -> S0
if
(
hierarchy_num
==
1
)
{
CHECK_OR_RETURN
(
Is1dSbp
(
nd_sbp
))
<<
NdSbpToString
(
nd_sbp
)
<<
" can not be converted to a 1d sbp!"
;
NdSbp
new_sbp
;
new_sbp
.
add_sbp_parallel
();
*
new_sbp
.
mutable_sbp_parallel
(
0
)
=
nd_sbp
.
sbp_parallel
(
0
);
return
new_sbp
;
}
// S0 -> (S0, S0)
CHECK_EQ_OR_RETURN
(
nd_sbp
.
sbp_parallel_size
(),
1
)
<<
"Illegal nd sbp transform."
;
NdSbp
new_sbp
;
for
(
int32_t
i
=
0
;
i
<
hierarchy_num
;
i
++
)
{
new_sbp
.
add_sbp_parallel
();
*
new_sbp
.
mutable_sbp_parallel
(
i
)
=
nd_sbp
.
sbp_parallel
(
0
);
}
return
new_sbp
;
}
}
// namespace
// A constructor with init, designed for uncustomized boxing collector
BoxingCollector
::
BoxingCollector
(
int32_t
max_axis
)
{
CHECK_JUST
(
Init
(
max_axis
));
}
// Construct a boxing collector with given maximum number of axis
Maybe
<
void
>
BoxingCollector
::
Init
(
int32_t
max_axis
)
{
// Not allowed two-step boxing and disable checking for debugging
if
(
ParseBooleanFromEnv
(
"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"
,
false
))
{
return
Maybe
<
void
>::
Ok
();
}
// Set up at least two split for op graph.
// For a negative example: Resnet50 only have B, P, S(0)
CollectUniverse
(
max_axis
);
GenerateNdSbpList
(
2
);
GenerateMap1d2nd
();
// Get copy cost in lazy mode
LazyMode
::
Guard
enable_lazy_mode
(
true
);
JUST
(
GenerateCombination4SamePlacement
(
3
));
JUST
(
GenerateCombination4DiffHierarchy
(
this
,
this
));
JUST
(
GenerateCombination4DiffPlacement
(
this
,
this
));
return
Maybe
<
void
>::
Ok
();
}
// Customized initialization with given blob and parallel description
Maybe
<
void
>
BoxingCollector
::
Init
(
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
parallel_desc
)
{
CollectUniverse
(
logical_blob_desc
.
shape
().
NumAxes
());
GenerateNdSbpList
(
parallel_desc
.
hierarchy
()
->
NumAxes
());
// Filter out unsuitable middle nodes before computing minimum cost.
JUST
(
FilterNdSbpList4LogicalShape
(
logical_blob_desc
,
*
parallel_desc
.
hierarchy
()));
GenerateMap1d2nd
();
// Get copy cost in lazy mode
LazyMode
::
Guard
enable_lazy_mode
(
true
);
JUST
(
GenerateCombination4SamePlacement
(
5
,
logical_blob_desc
,
parallel_desc
));
return
Maybe
<
void
>::
Ok
();
}
// Collect Sbp Parallel
void
BoxingCollector
::
CollectUniverse
(
const
SbpParallel
&
sbp
)
{
if
(
sbp_parallel_universe_
.
find
(
sbp
)
==
sbp_parallel_universe_
.
end
())
{
int32_t
curr_size
=
sbp_parallel_universe_
.
size
();
sbp_parallel_universe_
[
sbp
]
=
curr_size
;
id2sbp_parallel_
.
push_back
(
sbp
);
}
}
// Find corresponding id for Nd sbp
int32_t
BoxingCollector
::
FindId4NdSbp
(
const
NdSbp
&
nd_sbp
)
{
// Directly search on the nd_sbp_list
if
(
nd_sbp
.
sbp_parallel_size
()
==
hierarchy_num_
)
{
const
auto
&
it_nd_sbp
=
nd_sbp_universe_
.
find
(
nd_sbp
);
if
(
it_nd_sbp
!=
nd_sbp_universe_
.
end
())
{
return
it_nd_sbp
->
second
;
}
else
{
return
-
1
;
}
}
// Find the diagonal node if it could be converted to a 1D sbp
if
(
Is1dSbp
(
nd_sbp
))
{
const
auto
&
it_nd_sbp
=
sbp_parallel_universe_
.
find
(
nd_sbp
.
sbp_parallel
(
0
));
if
(
it_nd_sbp
!=
sbp_parallel_universe_
.
end
())
{
return
id_1d_2_nd_
[
it_nd_sbp
->
second
];
}
}
// Can not be converted to a 1D sbp or not found in the 1D sbp list
return
-
1
;
}
// Set default Sbp list
void
BoxingCollector
::
CollectUniverse
(
int32_t
max_axis
)
{
SbpParallel
sbp
;
sbp
.
mutable_broadcast_parallel
();
CollectUniverse
(
sbp
);
for
(
int32_t
axis
=
0
;
axis
<
max_axis
;
axis
++
)
{
sbp
.
mutable_split_parallel
()
->
set_axis
(
axis
);
CollectUniverse
(
sbp
);
}
sbp
.
mutable_partial_sum_parallel
();
CollectUniverse
(
sbp
);
}
// Generate nd sbp list
void
BoxingCollector
::
GenerateNdSbpList
(
int32_t
hierarchy_num
)
{
// 1D sbp does not support S->P. But it seems that we do not need to deal with it for now.
// And we do not have 3D sbp or higher dimension.
hierarchy_num_
=
hierarchy_num
;
// Generate possible nd_sbp lists
NdSbp
nd_sbp
;
for
(
int32_t
dim_sbp
=
0
;
dim_sbp
<
hierarchy_num
;
dim_sbp
++
)
{
nd_sbp
.
add_sbp_parallel
();
}
DfsSetNdSbp
(
id2sbp_parallel_
,
0
,
hierarchy_num
,
nd_sbp
,
nd_sbp_lists_
,
nd_sbp_universe_
);
}
// Generate the map from 1d sbp to 2d sbp
void
BoxingCollector
::
GenerateMap1d2nd
()
{
// Number of 1d sbp
int32_t
m
=
id2sbp_parallel_
.
size
();
// Generate the id Map from 1d sbp to nd sbp
NdSbp
nd_sbp
;
for
(
int32_t
dim_sbp
=
0
;
dim_sbp
<
hierarchy_num_
;
dim_sbp
++
)
{
nd_sbp
.
add_sbp_parallel
();
}
id_1d_2_nd_
.
resize
(
m
,
-
1
);
for
(
int32_t
id_1d
=
0
;
id_1d
<
m
;
id_1d
++
)
{
for
(
int32_t
dim_sbp
=
0
;
dim_sbp
<
hierarchy_num_
;
dim_sbp
++
)
{
*
nd_sbp
.
mutable_sbp_parallel
(
dim_sbp
)
=
id2sbp_parallel_
[
id_1d
];
}
// NOTE: The 2d sbp might be filtered out already.
const
auto
&
it_
=
nd_sbp_universe_
.
find
(
nd_sbp
);
if
(
it_
!=
nd_sbp_universe_
.
end
())
{
id_1d_2_nd_
[
id_1d
]
=
it_
->
second
;
}
}
}
// Generate the transfer rule for different combinations with the same hierarchy
Maybe
<
void
>
BoxingCollector
::
GenerateCombination4SamePlacement
(
int32_t
max_middle_node_num
)
{
// other parameters
// NOTE: The performance of this function are all the same with different hierarchy
int32_t
world_size
=
GlobalProcessCtx
::
WorldSize
();
Shape
hierarchy44
({
4
*
world_size
,
4
*
world_size
});
std
::
shared_ptr
<
Shape
>
virtual_hierarchy
=
std
::
make_shared
<
Shape
>
(
hierarchy44
);
auto
parallel_desc
=
JUST
(
ParallelDesc
::
New
(
"cpu"
,
{
"0:0-"
+
std
::
to_string
(
hierarchy44
.
elem_cnt
()
-
1
)},
virtual_hierarchy
));
BlobDesc
blob_desc
({
16
,
16
,
16
,
16
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
JUST
(
GenerateCombination4SamePlacement
(
max_middle_node_num
,
blob_desc
,
*
parallel_desc
));
return
Maybe
<
void
>::
Ok
();
}
// Generate the transfer rule for different combinations with the same hierarchy
Maybe
<
void
>
BoxingCollector
::
GenerateCombination4SamePlacement
(
int32_t
max_middle_node_num
,
const
BlobDesc
&
blob_desc
,
const
ParallelDesc
&
parallel_desc
)
{
// Store the origin transfer cost information
int32_t
n
=
nd_sbp_lists_
.
size
();
minimum_copy_cost_
.
resize
(
n
);
middle_nodes_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
minimum_copy_cost_
[
i
].
resize
(
n
);
middle_nodes_
[
i
].
resize
(
n
);
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
minimum_copy_cost_
[
i
][
j
]
=
JUST
(
ComputeLazyCopyCostBetweenNdSbp
(
nd_sbp_lists_
[
i
],
nd_sbp_lists_
[
j
],
blob_desc
,
parallel_desc
,
parallel_desc
,
/*requires_same_sbp=*/
false
));
}
}
auto
NotMiddleNode
=
[
&
](
int32_t
i
,
int32_t
j
,
int32_t
k
,
int32_t
middle_node_num_ik
)
->
bool
{
// Not allow i -> i -> j or i -> j -> j.
if
(
k
==
j
||
k
==
i
)
{
return
true
;
}
// We add middle nodes one by one
// Thus, we allow multiple nodes from i to k but we only accept 1 step from k to j.
// i -> ? -> k -> j
if
(
middle_nodes_
[
k
][
j
].
size
()
>
0
)
{
return
true
;
}
// To avoid multiple counting and bugs, the number of middle nodes between i and k
// must be exactly middle_node_num_ik, which is (middle_node_num - 1)
if
(
middle_node_num_ik
)
{
if
(
middle_nodes_
[
i
][
k
].
size
()
==
0
||
middle_nodes_
[
i
][
k
][
0
].
size
()
!=
middle_node_num_ik
)
{
return
true
;
}
}
else
{
if
(
middle_nodes_
[
i
][
k
].
size
()
>
0
)
{
return
true
;
}
}
return
false
;
};
for
(
int32_t
middle_node_num
=
1
;
middle_node_num
<=
max_middle_node_num
;
middle_node_num
++
)
{
int32_t
middle_node_num_ik
=
middle_node_num
-
1
;
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
if
(
minimum_copy_cost_
[
i
][
j
]
<
GetValidMaxCopyCost
())
{
continue
;
}
// Compute the smallest transfer cost
// k is the middle node, i -> k -> j
for
(
int32_t
k
=
0
;
k
<
n
;
k
++
)
{
if
(
NotMiddleNode
(
i
,
j
,
k
,
middle_node_num_ik
))
{
continue
;
}
double
curr_copy_cost
=
minimum_copy_cost_
[
i
][
k
]
+
minimum_copy_cost_
[
k
][
j
];
if
(
curr_copy_cost
<
minimum_copy_cost_
[
i
][
j
])
{
minimum_copy_cost_
[
i
][
j
]
=
curr_copy_cost
;
}
}
// If the minimum copy cost remians infinity, adding one middle node does not make it.
if
(
minimum_copy_cost_
[
i
][
j
]
>
GetValidMaxCopyCost
())
{
continue
;
}
// Find those middle nodes
for
(
int32_t
k
=
0
;
k
<
n
;
k
++
)
{
if
(
NotMiddleNode
(
i
,
j
,
k
,
middle_node_num_ik
))
{
continue
;
}
// Now we start to judge if the edge have a minimum cost
// It needs to be "<=" since we have 0 cost.
// Using "<" would give no middle nodes from (B, B) to any other nd sbp.
if
(
minimum_copy_cost_
[
i
][
k
]
+
minimum_copy_cost_
[
k
][
j
]
<=
minimum_copy_cost_
[
i
][
j
]
*
1.0000001
)
{
// i -> ? -> k
if
(
middle_nodes_
[
i
][
k
].
size
()
>
0
)
{
// We have multiple choices going from i to k
for
(
const
auto
&
middle_node_ik
:
middle_nodes_
[
i
][
k
])
{
middle_nodes_
[
i
][
j
].
push_back
(
middle_node_ik
);
middle_nodes_
[
i
][
j
][
middle_nodes_
[
i
][
j
].
size
()
-
1
].
push_back
(
k
);
}
}
else
{
// We only need one middle node k to reach j from i
middle_nodes_
[
i
][
j
].
push_back
({
k
});
}
}
}
CHECK_OR_RETURN
(
middle_nodes_
[
i
][
j
].
size
()
>
0
)
<<
"No middle nodes given from "
<<
NdSbpToString
(
nd_sbp_lists_
[
i
])
<<
" to "
<<
NdSbpToString
(
nd_sbp_lists_
[
j
])
<<
" in boxing collector"
;
}
}
}
return
Maybe
<
void
>::
Ok
();
}
// Generate the transfer rule for different combinations with different hierarchies on the same
// placement
Maybe
<
void
>
BoxingCollector
::
GenerateCombination4DiffHierarchy
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
)
{
// Store the boxing collector pointer
// Search the path that contains one of the diagonal sbp
int32_t
n
=
nd_sbp_lists_
.
size
();
diag_node_diff_hierarchy_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
diag_node_diff_hierarchy_
[
i
].
resize
(
n
);
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
JUST
(
Generate1Combination4DiffHierarchy
(
i
,
j
,
boxing_collector_producer
,
boxing_collector_consumer
,
diag_node_diff_hierarchy_
[
i
][
j
]));
}
}
return
Maybe
<
void
>::
Ok
();
}
// Generate the transfer rule for different combinations with different placements
Maybe
<
void
>
BoxingCollector
::
GenerateCombination4DiffPlacement
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
)
{
// Virtual parallel and blob description
int32_t
world_size
=
GlobalProcessCtx
::
WorldSize
();
BlobDesc
blob_desc
({
16
,
16
,
16
,
16
},
DataType
::
kInt8
,
/*is_dynamic=*/
false
);
// Virtual placements before transfer
Shape
in_hierarchy44
({
4
*
world_size
+
1
,
4
*
world_size
});
std
::
shared_ptr
<
Shape
>
in_hierarchy
=
std
::
make_shared
<
Shape
>
(
in_hierarchy44
);
auto
in_parallel_desc
=
JUST
(
ParallelDesc
::
New
(
"cpu"
,
{
"0:0-"
+
std
::
to_string
(
in_hierarchy44
.
elem_cnt
()
-
1
)},
in_hierarchy
));
// Virtual placements after transfer
Shape
out_hierarchy44
({
4
*
world_size
,
4
*
world_size
});
std
::
shared_ptr
<
Shape
>
out_hierarchy
=
std
::
make_shared
<
Shape
>
(
out_hierarchy44
);
auto
out_parallel_desc
=
JUST
(
ParallelDesc
::
New
(
"cpu"
,
{
"0:0-"
+
std
::
to_string
(
out_hierarchy44
.
elem_cnt
()
-
1
)},
out_hierarchy
));
JUST
(
GenerateCombination4DiffPlacement
(
boxing_collector_producer
,
boxing_collector_consumer
,
blob_desc
,
*
in_parallel_desc
,
*
out_parallel_desc
));
return
Maybe
<
void
>::
Ok
();
}
// The cost for transferring a 1D sbp between different placements
Maybe
<
void
>
BoxingCollector
::
ComputeCostFor1DSbpDiffPlacement
(
const
BlobDesc
&
blob_desc
,
const
ParallelDesc
&
in_parallel_desc
,
const
ParallelDesc
&
out_parallel_desc
,
std
::
vector
<
std
::
vector
<
double
>>&
cost_4_diff_placement
)
{
// Number of 1d sbp
int32_t
m
=
id2sbp_parallel_
.
size
();
// Compute the cost while transferring a 1D sbp between different placements
cost_4_diff_placement
.
resize
(
m
);
for
(
int32_t
id_1d_producer
=
0
;
id_1d_producer
<
m
;
id_1d_producer
++
)
{
cost_4_diff_placement
[
id_1d_producer
].
resize
(
m
,
GetMaxVal
<
float
>
());
int32_t
diag_producer
=
id_1d_2_nd_
[
id_1d_producer
];
if
(
diag_producer
<
0
)
{
continue
;
}
for
(
int32_t
id_1d_consumer
=
0
;
id_1d_consumer
<
m
;
id_1d_consumer
++
)
{
int32_t
diag_consumer
=
id_1d_2_nd_
[
id_1d_consumer
];
if
(
diag_consumer
<
0
)
{
continue
;
}
cost_4_diff_placement
[
id_1d_producer
][
id_1d_consumer
]
=
JUST
(
ComputeLazyCopyCostBetweenNdSbp
(
nd_sbp_lists_
[
diag_producer
],
nd_sbp_lists_
[
diag_consumer
],
blob_desc
,
in_parallel_desc
,
out_parallel_desc
,
false
));
}
}
return
Maybe
<
void
>::
Ok
();
}
// Generate the transfer rule for different combinations with different placements
Maybe
<
void
>
BoxingCollector
::
GenerateCombination4DiffPlacement
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
const
BlobDesc
&
blob_desc
,
const
ParallelDesc
&
in_parallel_desc
,
const
ParallelDesc
&
out_parallel_desc
)
{
// The cost for transferring a 1D sbp between different placements
std
::
vector
<
std
::
vector
<
double
>>
cost_4_diff_placement
;
// Compute the cost while transferring a 1D sbp between different placements
JUST
(
ComputeCostFor1DSbpDiffPlacement
(
blob_desc
,
in_parallel_desc
,
out_parallel_desc
,
cost_4_diff_placement
));
// Search the path that contains two of the diagonal sbp
int32_t
n
=
nd_sbp_lists_
.
size
();
diag_node_diff_placement_
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
diag_node_diff_placement_
[
i
].
resize
(
n
);
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
JUST
(
Generate1Combination4DiffPlacement
(
i
,
j
,
boxing_collector_producer
,
boxing_collector_consumer
,
cost_4_diff_placement
,
diag_node_diff_placement_
[
i
][
j
]));
}
}
return
Maybe
<
void
>::
Ok
();
}
// Print the cost and middle nodes
void
BoxingCollector
::
PrintBoxingTables
()
{
if
(
GlobalProcessCtx
::
Rank
()
==
0
)
{
std
::
cout
<<
"===================minimum copy cost=================="
<<
std
::
endl
;
// other parameters
// To be noted that the performance of this function are all the same with different hierarchy
Shape
hierarchy44
({
4
,
4
});
std
::
shared_ptr
<
Shape
>
in_hierarchy
=
std
::
make_shared
<
Shape
>
(
hierarchy44
);
double
logical_blob_size
=
1024.0
;
int32_t
n
=
nd_sbp_lists_
.
size
();
// Print the origin copy cost table
std
::
cout
<<
"Cost
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
j
])
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
i
])
<<
"
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
if
(
minimum_copy_cost_
[
i
][
j
]
>
GetValidMaxCopyCost
())
{
std
::
cout
<<
"X
\t
"
;
}
else
{
std
::
cout
<<
minimum_copy_cost_
[
i
][
j
]
<<
"
\t
"
;
}
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"Original Copy Cost"
<<
std
::
endl
;
std
::
cout
<<
"logical blob size: "
<<
logical_blob_size
<<
std
::
endl
;
std
::
cout
<<
"hierarchy: "
<<
*
in_hierarchy
<<
std
::
endl
;
std
::
cout
<<
"============================middle nodes==========================="
<<
std
::
endl
;
// Print the middle nodes
std
::
cout
<<
"Middle Sbp
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
j
])
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
i
])
<<
"
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
if
(
minimum_copy_cost_
[
i
][
j
]
>
GetValidMaxCopyCost
())
{
std
::
cout
<<
"X"
;
}
else
if
(
middle_nodes_
[
i
][
j
].
size
()
>
0
)
{
for
(
int32_t
k
=
0
;
k
<
middle_nodes_
[
i
][
j
].
size
();
k
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
middle_nodes_
[
i
][
j
][
k
][
0
]]);
for
(
int32_t
l
=
1
;
l
<
middle_nodes_
[
i
][
j
][
k
].
size
();
l
++
)
{
std
::
cout
<<
"->"
<<
NdSbpToString
(
nd_sbp_lists_
[
middle_nodes_
[
i
][
j
][
k
][
l
]]);
}
std
::
cout
<<
"; "
;
}
}
std
::
cout
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"Minimum Copy Cost after second search"
<<
std
::
endl
;
std
::
cout
<<
"logical blob size: "
<<
logical_blob_size
<<
std
::
endl
;
std
::
cout
<<
"hierarchy: "
<<
*
in_hierarchy
<<
std
::
endl
;
std
::
cout
<<
"====================middle nodes for different placement===================="
<<
std
::
endl
;
std
::
cout
<<
"Middle nodes for different placement
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
j
])
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
i
])
<<
"
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
if
(
diag_node_diff_placement_
[
i
][
j
].
size
()
>
0
)
{
for
(
int32_t
k
=
0
;
k
<
diag_node_diff_placement_
[
i
][
j
].
size
();
k
++
)
{
std
::
cout
<<
"["
<<
NdSbpToString
(
nd_sbp_lists_
[
diag_node_diff_placement_
[
i
][
j
][
k
][
0
]])
<<
", "
<<
NdSbpToString
(
nd_sbp_lists_
[
diag_node_diff_placement_
[
i
][
j
][
k
][
1
]])
<<
"]; "
;
}
}
std
::
cout
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
"====================middle nodes for different hierarchy===================="
<<
std
::
endl
;
std
::
cout
<<
"Middle nodes for different hierarchy
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
j
])
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
int32_t
i
=
0
;
i
<
n
;
i
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
i
])
<<
"
\t
"
;
for
(
int32_t
j
=
0
;
j
<
n
;
j
++
)
{
if
(
diag_node_diff_hierarchy_
[
i
][
j
].
size
()
>
0
)
{
for
(
int32_t
k
=
0
;
k
<
diag_node_diff_hierarchy_
[
i
][
j
].
size
();
k
++
)
{
std
::
cout
<<
NdSbpToString
(
nd_sbp_lists_
[
diag_node_diff_hierarchy_
[
i
][
j
][
k
][
0
]])
<<
"; "
;
}
}
std
::
cout
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
"================================================"
<<
std
::
endl
;
}
}
// Ask if the boxing algorithm accepts the current sbp combination
Maybe
<
void
>
BoxingCollector
::
AskSbpCombination
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
bool
is_customized
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
,
bool
compute_cost
)
{
middle_sbps
.
clear
();
// Not allowed two-step boxing and disable checking for debugging
if
(
ParseBooleanFromEnv
(
"ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"
,
false
))
{
return
Maybe
<
void
>::
Ok
();
}
// If compute_cost==false + 2D sbp + same placment + nccl logical + not (p->b),
// Use nccl logical send recv instead of middle node.
// Note that in op sbp inference, cost of middle nodes is still used for the moment.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
if
(
compute_cost
==
false
&&
producer_parallel_desc
.
hierarchy
()
->
NumAxes
()
==
2
&&
producer_parallel_desc
==
consumer_parallel_desc
&&
!
(
NdSbpHasPartialParallel
(
sbp_consumer
))
&&
// TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open
// this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P)
// -> (P, S0) -> (B, S0), neither same dim 0 or send recv in nccl logical pass can deal with
// (P, P) -> (P, S0) at the moment.
// !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) &&
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
nccl_use_compute_stream
())
{
VLOG
(
3
)
<<
"Middle node insertion is skipped when src sbp is "
<<
NdSbpToString
(
sbp_producer
)
<<
" dst sbp is "
<<
NdSbpToString
(
sbp_consumer
)
<<
", because nccl logical send/recv can handle this."
;
return
Maybe
<
void
>::
Ok
();
}
#endif // WITH_CUDA
// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
if
(
Is1dSbp
(
sbp_producer
)
&&
Is1dSbp
(
sbp_consumer
))
{
if
(
sbp_consumer
.
sbp_parallel
(
0
).
has_partial_sum_parallel
())
{
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if
(
producer_parallel_desc
.
parallel_num
()
==
consumer_parallel_desc
.
parallel_num
()
&&
sbp_producer
.
sbp_parallel
(
0
).
has_partial_sum_parallel
())
{
return
Maybe
<
void
>::
Ok
();
}
if
(
!
sbp_producer
.
sbp_parallel
(
0
).
has_broadcast_parallel
())
{
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
// We do not support [3]: P <--> [2, 2]: (P, P) as well.
int32_t
hierarchy_size
=
0
;
if
(
producer_parallel_desc
.
hierarchy
()
->
elem_cnt
()
<
consumer_parallel_desc
.
hierarchy
()
->
elem_cnt
())
{
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*
diag_node_pos
=
1
;
hierarchy_size
=
producer_parallel_desc
.
hierarchy
()
->
NumAxes
();
}
else
{
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*
diag_node_pos
=
0
;
hierarchy_size
=
consumer_parallel_desc
.
hierarchy
()
->
NumAxes
();
}
NdSbp
broadcast_nd
;
for
(
int32_t
i
=
0
;
i
<
hierarchy_size
;
i
++
)
{
broadcast_nd
.
add_sbp_parallel
();
broadcast_nd
.
mutable_sbp_parallel
(
i
)
->
mutable_broadcast_parallel
();
}
middle_sbps
.
emplace_back
(
broadcast_nd
);
}
return
Maybe
<
void
>::
Ok
();
}
}
// Middle nodes algorithm supports transfer for different machines or devices or hierarchies
if
(
producer_parallel_desc
!=
consumer_parallel_desc
)
{
JUST
(
AskSbpCombination4DiffPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
is_customized
,
middle_sbps
,
diag_node_pos
,
compute_cost
));
return
Maybe
<
void
>::
Ok
();
}
// Transfer for the same machines, devices and hierarchy.
if
(
sbp_producer
==
sbp_consumer
)
{
return
Maybe
<
void
>::
Ok
();
}
const
auto
&
parallel_hierarchy
=
producer_parallel_desc
.
hierarchy
();
*
diag_node_pos
=
0
;
// Dealing with nD sbp, n>2
if
(
parallel_hierarchy
->
NumAxes
()
>
2
)
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Boxing does not support a hierarchy with dimension greater than 2"
;
return
Maybe
<
void
>::
Ok
();
}
// Ask for sbp combination with the same 2-D hierarchy and placement
JUST
(
AskSbpCombination4Same2DPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
is_customized
,
middle_sbps
,
diag_node_pos
,
compute_cost
));
return
Maybe
<
void
>::
Ok
();
}
// Ask for sbp combination with the same 2-D hierarchy and placement
Maybe
<
void
>
BoxingCollector
::
AskSbpCombination4Same2DPlacement
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
bool
is_customized
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
,
bool
compute_cost
)
{
CHECK_OR_RETURN
(
producer_parallel_desc
==
consumer_parallel_desc
)
<<
"Producer and consumer have different placements, Please use AskSbpCombination directly"
;
middle_sbps
.
clear
();
// Find the 2D sbp id
int32_t
i
=
FindId4NdSbp
(
sbp_producer
);
int32_t
j
=
FindId4NdSbp
(
sbp_consumer
);
// Dealing with 2D sbp
if
(
i
>=
0
&&
j
>=
0
)
{
// Such combination can not be support with limited middle nodes
if
(
minimum_copy_cost_
[
i
][
j
]
>
GetValidMaxCopyCost
())
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Boxing does not support "
<<
NdSbpToString
(
sbp_producer
)
<<
" -> "
<<
NdSbpToString
(
sbp_consumer
)
<<
" for 2D sbp"
;
return
Maybe
<
void
>::
Ok
();
}
// Current design can deal with such combination. Do not need to insert middle nodes
if
(
middle_nodes_
[
i
][
j
].
size
()
==
0
)
{
return
Maybe
<
void
>::
Ok
();
}
// Find a list of middle nodes with minimum storage
int32_t
min_k
=
-
1
;
double
min_cost
=
GetValidMaxCopyCost
();
for
(
int32_t
k
=
0
;
k
<
middle_nodes_
[
i
][
j
].
size
();
k
++
)
{
double
curr_cost
=
0.0
;
for
(
int32_t
middle_sbp_id
:
middle_nodes_
[
i
][
j
][
k
])
{
Shape
logical_shape
=
logical_blob_desc
.
shape
();
// Storage4NdSbp would modify logical_shape2 as well
curr_cost
+=
Storage4NdSbp
(
nd_sbp_lists_
[
middle_sbp_id
],
logical_shape
,
*
producer_parallel_desc
.
hierarchy
());
if
(
curr_cost
>
GetValidMaxCopyCost
())
{
break
;
}
}
// store k if renew minimum cost
if
(
curr_cost
<
min_cost
)
{
min_k
=
k
;
min_cost
=
curr_cost
;
}
}
// If we found a list of middle nodes with current boxing collector
int32_t
producer_hierarchy_num
=
producer_parallel_desc
.
hierarchy
()
->
NumAxes
();
if
(
min_k
>=
0
)
{
for
(
int32_t
middle_sbp_id
:
middle_nodes_
[
i
][
j
][
min_k
])
{
middle_sbps
.
emplace_back
(
*
JUST
(
SetNdSbpDim
(
nd_sbp_lists_
[
middle_sbp_id
],
producer_hierarchy_num
)));
}
return
Maybe
<
void
>::
Ok
();
}
}
// // If we can not found a list of middle nodes even after customized boxing collector
if
(
is_customized
)
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Boxing does not support "
<<
NdSbpToString
(
sbp_producer
)
<<
" -> "
<<
NdSbpToString
(
sbp_consumer
)
<<
" for Shape: "
<<
logical_blob_desc
.
shape
();
return
Maybe
<
void
>::
Ok
();
}
// Customized boxing collector and try the algorithm again
BoxingCollector
customized_boxing_collector
;
JUST
(
customized_boxing_collector
.
Init
(
logical_blob_desc
,
producer_parallel_desc
));
JUST
(
customized_boxing_collector
.
AskSbpCombination4Same2DPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
/*is_customized=*/
true
,
middle_sbps
,
diag_node_pos
,
compute_cost
));
return
Maybe
<
void
>::
Ok
();
}
// Ask for sbp combination with different hierarchies and placements
Maybe
<
void
>
BoxingCollector
::
AskSbpCombination4DiffPlacement
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
bool
is_customized
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
,
bool
compute_cost
)
{
middle_sbps
.
clear
();
// Find the 2D sbp id
int32_t
i
=
FindId4NdSbp
(
sbp_producer
);
int32_t
j
=
FindId4NdSbp
(
sbp_consumer
);
// Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda
// Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2]
bool
same_placement
=
producer_parallel_desc
.
EqualsIgnoringHierarchy
(
consumer_parallel_desc
);
// Dealing with 2D sbp
if
(
i
>=
0
&&
j
>=
0
)
{
// Pure copy between machines and devices
if
(
i
==
j
&&
(
*
producer_parallel_desc
.
hierarchy
()
==
*
consumer_parallel_desc
.
hierarchy
()))
{
return
Maybe
<
void
>::
Ok
();
}
if
(
same_placement
)
{
// Different hierarchies
CHECK_OR_RETURN
(
diag_node_diff_hierarchy_
.
size
()
>
0
)
<<
"Have not initialzie the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffHierarchy(this, this)); "
"before Asking sbp combination for different parallel description."
;
if
(
JUST
(
Ask1Combination4DiffPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
is_customized
,
middle_sbps
,
diag_node_pos
,
compute_cost
,
this
,
this
,
diag_node_diff_hierarchy_
[
i
][
j
])))
{
return
Maybe
<
void
>::
Ok
();
}
}
else
{
// Different placements
CHECK_OR_RETURN
(
diag_node_diff_placement_
.
size
()
>
0
)
<<
"Have not initialzie the combination table for different hierarchies yet! "
"Please run JUST(GenerateCombination4DiffPlacement(this, this)); "
"before Asking sbp combination for different parallel description."
;
if
(
JUST
(
Ask1Combination4DiffPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
is_customized
,
middle_sbps
,
diag_node_pos
,
compute_cost
,
this
,
this
,
diag_node_diff_placement_
[
i
][
j
])))
{
return
Maybe
<
void
>::
Ok
();
}
}
}
// Customized boxing collector and try the algorithm again
if
(
is_customized
)
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Boxing does not support "
<<
NdSbpToString
(
sbp_producer
)
<<
"[hierarchy: "
<<
*
producer_parallel_desc
.
hierarchy
()
<<
"] -> "
<<
NdSbpToString
(
sbp_consumer
)
<<
"[hierarchy: "
<<
*
consumer_parallel_desc
.
hierarchy
()
<<
"] for blob shape: "
<<
logical_blob_desc
.
shape
();
return
Maybe
<
void
>::
Ok
();
}
// Customize boxing collector for producer
BoxingCollector
customized_boxing_collector_producer
;
JUST
(
customized_boxing_collector_producer
.
Init
(
logical_blob_desc
,
producer_parallel_desc
));
// Customize boxing collector for consumer
BoxingCollector
customized_boxing_collector_consumer
;
JUST
(
customized_boxing_collector_consumer
.
Init
(
logical_blob_desc
,
consumer_parallel_desc
));
std
::
vector
<
std
::
vector
<
int32_t
>>
diag_nodes
;
// Generate the combination table for different hierarchies or placements
if
(
same_placement
)
{
JUST
(
customized_boxing_collector_producer
.
Generate1Combination4DiffHierarchy
(
customized_boxing_collector_producer
.
FindId4NdSbp
(
sbp_producer
),
customized_boxing_collector_consumer
.
FindId4NdSbp
(
sbp_consumer
),
&
customized_boxing_collector_producer
,
&
customized_boxing_collector_consumer
,
diag_nodes
));
}
else
{
// Compute the cost while transferring a 1D sbp between different placements
std
::
vector
<
std
::
vector
<
double
>>
cost_4_diff_placement
;
JUST
(
ComputeCostFor1DSbpDiffPlacement
(
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
cost_4_diff_placement
));
JUST
(
customized_boxing_collector_producer
.
Generate1Combination4DiffPlacement
(
customized_boxing_collector_producer
.
FindId4NdSbp
(
sbp_producer
),
customized_boxing_collector_consumer
.
FindId4NdSbp
(
sbp_consumer
),
&
customized_boxing_collector_producer
,
&
customized_boxing_collector_consumer
,
cost_4_diff_placement
,
diag_nodes
));
}
JUST
(
customized_boxing_collector_producer
.
Ask1Combination4DiffPlacement
(
sbp_producer
,
sbp_consumer
,
logical_blob_desc
,
producer_parallel_desc
,
consumer_parallel_desc
,
/*is_customized=*/
true
,
middle_sbps
,
diag_node_pos
,
compute_cost
,
&
customized_boxing_collector_producer
,
&
customized_boxing_collector_consumer
,
diag_nodes
));
return
Maybe
<
void
>::
Ok
();
}
// Generate the transfer rule for one combination with different hierarchies on the same
// placement. id_producer -> id_consumer.
Maybe
<
void
>
BoxingCollector
::
Generate1Combination4DiffHierarchy
(
int32_t
id_producer
,
int32_t
id_consumer
,
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
std
::
vector
<
std
::
vector
<
int32_t
>>&
diag_nodes
)
{
// Number of 1d sbp
int32_t
m
=
id2sbp_parallel_
.
size
();
// Search the path that contains one of the diagonal sbp
// minimum number of node
int32_t
min_path_length
=
100
;
// minimum cost
double
min_cost
=
GetValidMaxCopyCost
();
for
(
int32_t
id_1d
=
0
;
id_1d
<
m
;
id_1d
++
)
{
// We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)
// Thus, the diagonal node should suit both the hierarchies.
int32_t
diag_producer
=
boxing_collector_producer
->
id_1d_2_nd_
[
id_1d
];
if
(
diag_producer
<
0
)
{
continue
;
}
int32_t
diag_consumer
=
boxing_collector_consumer
->
id_1d_2_nd_
[
id_1d
];
if
(
diag_consumer
<
0
)
{
continue
;
}
// Find the path with minimum number of nodes
int32_t
path_length
=
0
;
// Transfer from id_producer to id_2d
if
(
boxing_collector_producer
->
middle_nodes_
[
id_producer
][
diag_producer
].
size
()
>
0
)
{
path_length
+=
boxing_collector_producer
->
middle_nodes_
[
id_producer
][
diag_producer
][
0
].
size
()
+
1
;
}
else
if
(
id_producer
!=
diag_producer
)
{
path_length
++
;
}
// Transfer from id_2d to id_consumer
if
(
boxing_collector_consumer
->
middle_nodes_
[
diag_consumer
][
id_consumer
].
size
()
>
0
)
{
path_length
+=
boxing_collector_consumer
->
middle_nodes_
[
diag_consumer
][
id_consumer
][
0
].
size
()
+
1
;
}
else
if
(
diag_consumer
!=
id_consumer
)
{
path_length
++
;
}
// Pick the path with minimum copy cost
if
(
path_length
<=
min_path_length
)
{
double
curr_cost
=
boxing_collector_producer
->
minimum_copy_cost_
[
id_producer
][
diag_producer
]
+
boxing_collector_consumer
->
minimum_copy_cost_
[
diag_consumer
][
id_consumer
];
min_path_length
=
path_length
;
// Find a candidate with small cost
if
(
curr_cost
<
min_cost
*
1.0000001
)
{
// Find a smaller cost, clear the previous path.
if
(
curr_cost
<
min_cost
*
0.9999999
)
{
min_cost
=
curr_cost
;
diag_nodes
.
clear
();
}
// Add the current diagonal node
// Asymmetry happens here. We can only store one side of the diagonal node.
// We do not store diag_consumer
diag_nodes
.
push_back
({
diag_producer
,
diag_consumer
});
}
}
}
return
Maybe
<
void
>::
Ok
();
}
// Ask for one combination with different hierarchies and placements
Maybe
<
bool
>
BoxingCollector
::
Ask1Combination4DiffPlacement
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
bool
is_customized
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
,
bool
compute_cost
,
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
const
std
::
vector
<
std
::
vector
<
int32_t
>>&
diag_nodes
)
{
// Pick the path with minimum storage for the diagonal node
int32_t
id_producer
=
boxing_collector_producer
->
FindId4NdSbp
(
sbp_producer
);
if
(
id_producer
<
0
)
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Source data with shape "
<<
logical_blob_desc
.
shape
()
<<
" has an invalid sbp "
<<
NdSbpToString
(
sbp_producer
);
return
false
;
}
int32_t
id_consumer
=
boxing_collector_consumer
->
FindId4NdSbp
(
sbp_consumer
);
if
(
id_consumer
<
0
)
{
CHECK_OR_RETURN
(
compute_cost
)
<<
"Target data with shape "
<<
logical_blob_desc
.
shape
()
<<
" has an invalid sbp "
<<
NdSbpToString
(
sbp_consumer
);
return
false
;
}
middle_sbps
.
clear
();
// NOTE: For simplicity, We do not dig into those storage cost for the other middle nodes at
// this moment.
double
min_cost
=
GetValidMaxCopyCost
();
int32_t
producer_hierarchy_num_axes
=
producer_parallel_desc
.
hierarchy
()
->
NumAxes
();
int32_t
consumer_hierarchy_num_axes
=
consumer_parallel_desc
.
hierarchy
()
->
NumAxes
();
int32_t
min_diag_producer
=
-
1
,
min_diag_consumer
=
-
1
;
for
(
const
auto
&
diag_pair
:
diag_nodes
)
{
Shape
logical_shape
=
logical_blob_desc
.
shape
();
// We do not check whether such shape is valid under two side of the sbp list in the
// middle nodes algorithm. Thus, we need to check them here.
double
curr_cost
=
Storage4NdSbp
(
*
JUST
(
SetNdSbpDim
(
boxing_collector_producer
->
nd_sbp_lists_
[
diag_pair
[
0
]],
producer_hierarchy_num_axes
)),
logical_shape
,
*
producer_parallel_desc
.
hierarchy
());
// Check the shape for both producer and consumer.
logical_shape
=
logical_blob_desc
.
shape
();
curr_cost
+=
Storage4NdSbp
(
*
JUST
(
SetNdSbpDim
(
boxing_collector_consumer
->
nd_sbp_lists_
[
diag_pair
[
1
]],
consumer_hierarchy_num_axes
)),
logical_shape
,
*
consumer_parallel_desc
.
hierarchy
());
if
(
curr_cost
<
min_cost
)
{
min_cost
=
curr_cost
;
min_diag_producer
=
diag_pair
[
0
];
min_diag_consumer
=
diag_pair
[
1
];
}
}
// Different placements: [2, 3] vs 5, or [3, 2] vs [2, 2], or cpu vs cuda
// Different hierarchies: [2, 3] vs 5, or [4, 3] vs [6, 2]
bool
diff_placement
=
!
producer_parallel_desc
.
EqualsIgnoringHierarchy
(
consumer_parallel_desc
);
// If we found a diagonal middle node with current boxing collector
if
(
min_diag_producer
>=
0
)
{
std
::
vector
<
NdSbp
>
middle_sbps_buffer
;
// Find the middle nodes between the producer and the diagonal node
if
(
id_producer
!=
min_diag_producer
)
{
JUST
(
boxing_collector_producer
->
AskSbpCombination
(
sbp_producer
,
boxing_collector_producer
->
nd_sbp_lists_
[
min_diag_producer
],
logical_blob_desc
,
producer_parallel_desc
,
producer_parallel_desc
,
/*is_customized=*/
false
,
middle_sbps_buffer
,
diag_node_pos
,
compute_cost
));
// Add the path into middle_sbps
for
(
auto
&
middle_sbp
:
middle_sbps_buffer
)
{
middle_sbps
.
emplace_back
(
*
JUST
(
SetNdSbpDim
(
middle_sbp
,
producer_hierarchy_num_axes
)));
}
// If different placement,
// or the same placement but with 2D hierarchies
// For example: Oneflow supports [6]: (S0) -> [3, 2]: (S0, S1)
// but does not support [2, 3]: (S0, S0) -> [3, 2]: (S0, S1)
if
(
diff_placement
||
producer_hierarchy_num_axes
>
1
)
{
middle_sbps
.
emplace_back
(
*
JUST
(
SetNdSbpDim
(
boxing_collector_producer
->
nd_sbp_lists_
[
min_diag_producer
],
producer_hierarchy_num_axes
)));
}
}
// If we do not have middle nodes on the consumer side
*
diag_node_pos
=
middle_sbps
.
size
();
// Find the middle nodes between the diagonal node and the consumer
if
(
id_consumer
!=
min_diag_consumer
)
{
JUST
(
boxing_collector_consumer
->
AskSbpCombination
(
boxing_collector_consumer
->
nd_sbp_lists_
[
min_diag_consumer
],
sbp_consumer
,
logical_blob_desc
,
consumer_parallel_desc
,
consumer_parallel_desc
,
/*is_customized=*/
false
,
middle_sbps_buffer
,
diag_node_pos
,
compute_cost
));
// Set the diagonal node position and stop using it as buffer
*
diag_node_pos
=
middle_sbps
.
size
();
// If different placement
if
(
diff_placement
||
consumer_hierarchy_num_axes
>
1
)
{
middle_sbps
.
emplace_back
(
*
JUST
(
SetNdSbpDim
(
boxing_collector_consumer
->
nd_sbp_lists_
[
min_diag_consumer
],
consumer_hierarchy_num_axes
)));
}
// Add the path into middle_sbps
for
(
auto
&
middle_sbp
:
middle_sbps_buffer
)
{
middle_sbps
.
emplace_back
(
*
JUST
(
SetNdSbpDim
(
middle_sbp
,
consumer_hierarchy_num_axes
)));
}
}
return
true
;
}
return
false
;
}
// Generate the transfer rule for one combination with different placements
// id_producer -> id_consumer.
Maybe
<
void
>
BoxingCollector
::
Generate1Combination4DiffPlacement
(
int32_t
id_producer
,
int32_t
id_consumer
,
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
const
std
::
vector
<
std
::
vector
<
double
>>&
cost_4_diff_placement
,
std
::
vector
<
std
::
vector
<
int32_t
>>&
diag_nodes
)
{
// Number of 1d sbp
int32_t
m
=
id2sbp_parallel_
.
size
();
// minimum number of node
int32_t
min_path_length
=
100
;
// minimum cost
double
min_cost
=
GetValidMaxCopyCost
();
// Search the path that contains two of the diagonal sbp
// From the producer to the first diagonal node
for
(
int32_t
id_1d_producer
=
0
;
id_1d_producer
<
m
;
id_1d_producer
++
)
{
// We do not support [2, 3]: (S0, S1) -> [6]: S0 for a tensor with shape (14, 21)
// Thus, the diagonal node should suit both the hierarchies.
int32_t
diag_producer
=
boxing_collector_producer
->
id_1d_2_nd_
[
id_1d_producer
];
if
(
diag_producer
<
0
||
boxing_collector_producer
->
minimum_copy_cost_
[
id_producer
][
diag_producer
]
>
GetValidMaxCopyCost
())
{
continue
;
}
// Find the path with minimum number of nodes
int32_t
path_length
=
0
;
// Transfer from id_producer to diag_producer
if
(
boxing_collector_producer
->
middle_nodes_
[
id_producer
][
diag_producer
].
size
()
>
0
)
{
path_length
+=
boxing_collector_producer
->
middle_nodes_
[
id_producer
][
diag_producer
][
0
].
size
()
+
1
;
}
else
if
(
id_producer
!=
diag_producer
)
{
path_length
++
;
}
// pruning
if
(
path_length
>
min_path_length
)
{
continue
;
}
// From the second diagonal node to the consumer
for
(
int32_t
id_1d_consumer
=
0
;
id_1d_consumer
<
m
;
id_1d_consumer
++
)
{
int32_t
diag_consumer
=
boxing_collector_consumer
->
id_1d_2_nd_
[
id_1d_consumer
];
// The diagonal sbp is not supported or no paths exist from the diagonal sbp to the
// consumer or between the two diagonal sbps.
if
(
diag_consumer
<
0
||
boxing_collector_consumer
->
minimum_copy_cost_
[
diag_consumer
][
id_consumer
]
>
GetValidMaxCopyCost
()
||
cost_4_diff_placement
[
id_1d_producer
][
id_1d_consumer
]
>
GetValidMaxCopyCost
())
{
continue
;
}
// Transfer from diag_consumer to id_consumer
int32_t
curr_path_length
=
path_length
;
if
(
boxing_collector_consumer
->
middle_nodes_
[
diag_consumer
][
id_consumer
].
size
()
>
0
)
{
curr_path_length
+=
boxing_collector_consumer
->
middle_nodes_
[
diag_consumer
][
id_consumer
][
0
].
size
()
+
1
;
}
else
if
(
diag_consumer
!=
id_consumer
)
{
curr_path_length
++
;
}
// Pick the path with minimum copy cost
if
(
curr_path_length
<=
min_path_length
)
{
double
curr_cost
=
boxing_collector_producer
->
minimum_copy_cost_
[
id_producer
][
diag_producer
]
+
cost_4_diff_placement
[
id_1d_producer
][
id_1d_consumer
]
+
boxing_collector_consumer
->
minimum_copy_cost_
[
diag_consumer
][
id_consumer
];
min_path_length
=
curr_path_length
;
// Find a candidate with small cost
if
(
curr_cost
<
min_cost
*
1.0000001
)
{
// Find a smaller cost, clear the previous path.
if
(
curr_cost
<
min_cost
*
0.9999999
)
{
min_cost
=
curr_cost
;
diag_nodes
.
clear
();
}
// Add the current diagonal node
// Asymmetry happens here. We can only store one side of the diagonal node.
// We do not store diag_consumer
diag_nodes
.
push_back
({
diag_producer
,
diag_consumer
});
}
}
}
}
return
Maybe
<
void
>::
Ok
();
}
// Filter nd sbp from nd_sbp_lists_ with given logical shape
Maybe
<
void
>
BoxingCollector
::
FilterNdSbpList4LogicalShape
(
const
BlobDesc
&
logical_blob_desc
,
const
Shape
&
parallel_hierarchy
)
{
for
(
int32_t
middle_sbp_id
=
nd_sbp_lists_
.
size
()
-
1
;
middle_sbp_id
>=
0
;
middle_sbp_id
--
)
{
Shape
logical_shape
=
logical_blob_desc
.
shape
();
if
(
JUST
(
FilterNdSbpByLogicalShape
(
nd_sbp_lists_
[
middle_sbp_id
],
logical_shape
,
parallel_hierarchy
)))
{
// Change the value before erasing
// This might be true: nd_sbp_lists_.size() - 1 == middle_sbp_id
nd_sbp_universe_
[
nd_sbp_lists_
[
nd_sbp_lists_
.
size
()
-
1
]]
=
middle_sbp_id
;
nd_sbp_universe_
.
erase
(
nd_sbp_lists_
[
middle_sbp_id
]);
nd_sbp_lists_
[
middle_sbp_id
]
=
nd_sbp_lists_
[
nd_sbp_lists_
.
size
()
-
1
];
nd_sbp_lists_
.
pop_back
();
}
}
return
Maybe
<
void
>::
Ok
();
}
}
// namespace oneflow
oneflow/core/auto_parallel/boxing_collector.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
#define ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/framework/sbp_infer_util.h"
namespace
oneflow
{
class
BoxingCollector
final
{
public:
BoxingCollector
()
=
default
;
~
BoxingCollector
()
=
default
;
// A constructor with init, designed for uncustomized boxing collector
BoxingCollector
(
int32_t
max_axis
);
// Set default Sbp list
void
CollectUniverse
(
int32_t
max_axis
);
// Construct a boxing collector with given maximum number of axis
Maybe
<
void
>
Init
(
int32_t
max_axis
);
// Init with given blob description
Maybe
<
void
>
Init
(
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
parallel_desc
);
// Generate nd sbp list
void
GenerateNdSbpList
(
int32_t
hierarchy_num
);
// Generate the map from 1d sbp to 2d sbp
void
GenerateMap1d2nd
();
// Generate the transfer rule for different combinations with the same hierarchy
Maybe
<
void
>
GenerateCombination4SamePlacement
(
int32_t
max_middle_node_num
);
Maybe
<
void
>
GenerateCombination4SamePlacement
(
int32_t
max_middle_node_num
,
const
BlobDesc
&
blob_desc
,
const
ParallelDesc
&
parallel_desc
);
// Generate the transfer rule for different combinations with different hierarchies
// on the same placement
Maybe
<
void
>
GenerateCombination4DiffHierarchy
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
);
// Generate the transfer rule for different combinations with different placements
Maybe
<
void
>
GenerateCombination4DiffPlacement
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
);
Maybe
<
void
>
GenerateCombination4DiffPlacement
(
BoxingCollector
*
boxing_collector_producer
,
BoxingCollector
*
boxing_collector_consumer
,
const
BlobDesc
&
blob_desc
,
const
ParallelDesc
&
in_parallel_desc
,
const
ParallelDesc
&
out_parallel_desc
);
// Print the cost and middle nodes
void
PrintBoxingTables
();
// Ask if the boxing algorithm accepts the current sbp combination
// If is_customized is true and we can not find a middle node list with
// resonable cost, error occurs.
// If compute_cost is true, then no error occur even if no suitable middle nodes paths found.
// For different placements, we would return a diagonal node.
// Before this diagonal node (< *diag_node_pos), we use the parallel description of the producer.
// After this diagonal node (>= *diag_node_pos), we use the parallel description of the consumer.
Maybe
<
void
>
AskSbpCombination
(
const
NdSbp
&
sbp_producer
,
const
NdSbp
&
sbp_consumer
,
const
BlobDesc
&
logical_blob_desc
,
const
ParallelDesc
&
producer_parallel_desc
,
const
ParallelDesc
&
consumer_parallel_desc
,
bool
is_customized
,
std
::
vector
<
NdSbp
>&
middle_sbps
,
int32_t
*
diag_node_pos
,
bool
compute_cost
);
// Filter nd sbp from nd_sbp_lists_ with given logical shape
Maybe
<
void
>
FilterNdSbpList4LogicalShape
(
const
BlobDesc
&
logical_blob_desc
,
const
Shape
&
parallel_hierarchy
);
 private:
  // Collect Sbp Parallel
  void CollectUniverse(const SbpParallel& sbp);
  // Find corresponding id for Nd sbp
  int32_t FindId4NdSbp(const NdSbp& nd_sbp);
  // Ask for sbp combination with the same 2-D hierarchy and placement
  Maybe<void> AskSbpCombination4Same2DPlacement(
      const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
      const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
      bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
      bool compute_cost);
  // Ask for sbp combination with different hierarchies on the same placement
  Maybe<void> AskSbpCombination4DiffPlacement(
      const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
      const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
      bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
      bool compute_cost);
  // Generate the transfer rule for one combination with different hierarchies on the same
  // placement. id_producer -> id_consumer.
  Maybe<void> Generate1Combination4DiffHierarchy(int32_t id_producer, int32_t id_consumer,
                                                 BoxingCollector* boxing_collector_producer,
                                                 BoxingCollector* boxing_collector_consumer,
                                                 std::vector<std::vector<int32_t>>& diag_nodes);
  // The cost for transferring a 1D sbp between different placements
  Maybe<void> ComputeCostFor1DSbpDiffPlacement(
      const BlobDesc& blob_desc, const ParallelDesc& in_parallel_desc,
      const ParallelDesc& out_parallel_desc,
      std::vector<std::vector<double>>& cost_4_diff_placement);
  // Generate the transfer rule for one combination with different placements
  // id_producer -> id_consumer.
  Maybe<void> Generate1Combination4DiffPlacement(
      int32_t id_producer, int32_t id_consumer, BoxingCollector* boxing_collector_producer,
      BoxingCollector* boxing_collector_consumer,
      const std::vector<std::vector<double>>& cost_4_diff_placement,
      std::vector<std::vector<int32_t>>& diag_nodes);
  // Ask for one combination with different hierarchies and placements
  Maybe<bool> Ask1Combination4DiffPlacement(
      const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
      const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
      bool is_customized, std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos,
      bool compute_cost, BoxingCollector* boxing_collector_producer,
      BoxingCollector* boxing_collector_consumer,
      const std::vector<std::vector<int32_t>>& diag_nodes);

  // Stores all the possible SbpParallel.
  HashMap<::oneflow::SbpParallel, int32_t> sbp_parallel_universe_;
  // Relationship between id and Sbp Parallel
  std::vector<::oneflow::SbpParallel> id2sbp_parallel_;
  // minimum cost: minimum_copy_cost_[producer][consumer]
  std::vector<std::vector<double>> minimum_copy_cost_;
  // middle nodes
  // middle_nodes_[producer][consumer][different choices] is a vector of middle nodes
  // middle_nodes_[producer][consumer][different choices].size() is the minimum number of middle
  // nodes that need to be inserted
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> middle_nodes_;
  // Stores all the possible NdSbp.
  std::unordered_map<::oneflow::NdSbp, int32_t> nd_sbp_universe_;
  // Relationship between id and Nd Sbp
  std::vector<NdSbp> nd_sbp_lists_;
  // The diagonal middle node for different placements
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_placement_;
  // The diagonal middle node for different hierarchies in the same placement
  std::vector<std::vector<std::vector<std::vector<int32_t>>>> diag_node_diff_hierarchy_;
  // Id map from 1d sbp to 2d sbp
  // For example: B -> (B, B), S0 -> (S0, S0)
  std::vector<int32_t> id_1d_2_nd_;
  // The sbp size in the combination table
  int32_t hierarchy_num_;
};  // class BoxingCollector

}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTO_PARALLEL_BOXING_COLLECTOR_H_
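The members above pair a lookup table with an id-indexed vector: sbp_parallel_universe_ / id2sbp_parallel_ and nd_sbp_universe_ / nd_sbp_lists_ intern each distinct (Nd)Sbp as a dense integer id, so the cost and middle-node tables can be plain id-indexed vectors. Below is a minimal standalone sketch of that interning pattern, with an illustrative Interner type and std::string standing in for SbpParallel; none of these names are OneFlow API.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Assigns a dense id to each distinct value; the id doubles as an index into tables.
template<typename T>
class Interner {
 public:
  int32_t CollectUniverse(const T& value) {
    auto it = value2id_.find(value);
    if (it != value2id_.end()) { return it->second; }
    int32_t id = static_cast<int32_t>(id2value_.size());
    value2id_.emplace(value, id);
    id2value_.push_back(value);
    return id;
  }
  const T& Value(int32_t id) const { return id2value_.at(id); }
  size_t size() const { return id2value_.size(); }

 private:
  std::unordered_map<T, int32_t> value2id_;  // analogous to *_universe_
  std::vector<T> id2value_;                  // analogous to id2sbp_parallel_ / nd_sbp_lists_
};

int main() {
  Interner<std::string> sbp_universe;
  int32_t b = sbp_universe.CollectUniverse("B");
  int32_t s0 = sbp_universe.CollectUniverse("S(0)");
  assert(sbp_universe.CollectUniverse("B") == b);  // same value, same id
  // A cost table indexed by the dense ids, in the spirit of minimum_copy_cost_.
  std::vector<std::vector<double>> cost(sbp_universe.size(),
                                        std::vector<double>(sbp_universe.size(), 0.0));
  cost[b][s0] = 1.0;
  return 0;
}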
oneflow/core/autograd/autograd_captured_tensor.h
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_

#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

class AutogradCapturedTensor final : public ProxyTensor<AutogradCapturedTensor> {
 public:
  static Maybe<AutogradCapturedTensor> MakeTensor(const std::shared_ptr<Tensor>& tensor) {
    if (tensor->requires_grad()) {
      CHECK_NOTNULL_OR_RETURN(tensor->grad_fn_node().get())
          << Error::RuntimeError()
          << "a grad function node is expected for the captured tensor "
             "which requires_grad is True.";
    }
    std::shared_ptr<AutogradCapturedTensor> captured_tensor(
        new AutogradCapturedTensor(JUST(tensor->detach())));
    captured_tensor->set_autograd_meta(tensor->mut_autograd_meta());
    captured_tensor->grad_fn_node_ = tensor->mut_grad_fn_node();
    return captured_tensor;
  }

  std::shared_ptr<const FunctionNode> grad_fn_node() const override {
    return grad_fn_node_.lock();
  }
  void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {
    PRINT_BUG_PROMPT_AND_ABORT();
  }
  std::shared_ptr<FunctionNode> mut_grad_fn_node() override { return grad_fn_node_.lock(); }

  std::shared_ptr<Tensor> contiguous() const override {
    const auto& tensor = std::const_pointer_cast<Tensor>(shared_from_this());
    if (tensor_->is_contiguous()) { return tensor; }
    return CHECK_JUST(functional::ToContiguous(tensor));
  }

 private:
  explicit AutogradCapturedTensor(const std::shared_ptr<Tensor>& tensor)
      : ProxyTensor<AutogradCapturedTensor>(tensor) {}

 private:
  std::weak_ptr<FunctionNode> grad_fn_node_;
};

}  // namespace one
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_CAPTURED_TENSOR_H_
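AutogradCapturedTensor above keeps its grad function node in a std::weak_ptr: the FunctionNode, through the tensors it captures for backward, can reach the captured tensor again, so a strong pointer back to the node would create a reference cycle and keep the whole backward subgraph alive. A standalone sketch of that ownership shape follows, assuming toy Node/Captured types rather than the OneFlow classes.

#include <cassert>
#include <memory>

struct Captured;  // holds a weak back-reference to the node that produced it

struct Node {
  std::shared_ptr<Captured> saved;  // the node keeps its captured value alive
};

struct Captured {
  std::weak_ptr<Node> grad_fn;  // weak: avoids a Node <-> Captured ownership cycle
  std::shared_ptr<Node> lock() const { return grad_fn.lock(); }
};

int main() {
  auto captured = std::make_shared<Captured>();
  {
    auto node = std::make_shared<Node>();
    node->saved = captured;
    captured->grad_fn = node;
    assert(captured->lock() != nullptr);  // node is still alive inside this scope
  }
  // Once the last shared_ptr to the node is gone, lock() observes expiry instead of
  // keeping the node (and everything it saved) alive forever.
  assert(captured->lock() == nullptr);
  return 0;
}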
oneflow/core/autograd/autograd_engine.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <stack>
#include <queue>

#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {

namespace {

void GatherFunctionNodes(FunctionNode* node, std::stack<std::shared_ptr<FunctionNode>>& stack) {
  for (auto& prev_node : node->next_functions()) {
    if (prev_node) {
      if (prev_node.use_count() == 1) { stack.push(prev_node); }
    }
  }
}

/* NOTE:
 * Stack overflows when releasing a very deep computation graph without
 * a custom deleter.
 *
 * For example, here is a very deep computation graph:
 * Tensor -> FunctionNode -> Tensor -> FunctionNode -> ... -> Tensor -> FunctionNode
 * When releasing the first Tensor, it will trigger the recursive deletion and stack overflow.
 *
 * So we must set a custom deleter and release them iteratively.
 */
void FunctionNodeDeleter(FunctionNode* node) {
  std::stack<std::shared_ptr<FunctionNode>> stack;
  node->ReleaseData();
  GatherFunctionNodes(node, stack);
  delete node;

  while (!stack.empty()) {
    auto now_node = std::move(stack.top());
    stack.pop();
    now_node->ReleaseData();
    GatherFunctionNodes(now_node.get(), stack);
  }
}

bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_datas) {
  return std::any_of(out_meta_datas.begin(), out_meta_datas.end(),
                     [](const std::shared_ptr<AutogradMeta>& meta_data) {
                       return !meta_data->current_grad()->Empty();
                     });
}

Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
  autograd::AutoGradMode mode(autograd_mode);
  auto current_grad = JUST(autograd_meta->current_grad()->GetAccTensor({}));
  if (!current_grad) { return Maybe<void>::Ok(); }
  if (autograd_meta->acc_grad()) {
    // Should not inplace accumulate grad. For example,
    // >>> z = x + y
    // >>> p = x / z
    // >>> p.sum().backward()
    //
    // As we know that dx = dz + dp / z and dy = dz, so it will lead to wrong value
    // for dy if dx is shared with dz.
    const auto& output =
        JUST(functional::Add(autograd_meta->acc_grad(), current_grad, /*alpha=*/1,
                             /*inplace=*/autograd_meta->is_grad_acc_inplace()));
    JUST(autograd_meta->set_acc_grad(output));
  } else {
    JUST(autograd_meta->set_acc_grad(current_grad));
  }
  for (const auto& hook : autograd_meta->post_grad_accumulation_hooks()) {
    auto new_grad = hook(autograd_meta->acc_grad());
    if (new_grad) { JUST(autograd_meta->set_acc_grad(new_grad)); }
  }
  return Maybe<void>::Ok();
}

Maybe<void> RawTorchConsistentTensor(const std::shared_ptr<one::Tensor>& tensor) {
  // Do nothing.
  return Maybe<void>::Ok();
}

static constexpr auto* TorchConsistentTensor =
    DECORATE(&RawTorchConsistentTensor, CheckConsistentTensorMeta);

Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) {
  for (const auto& tensor : tensor_tuple) {
    if (tensor->is_consistent()) { JUST(TorchConsistentTensor(tensor)); }
  }
  return Maybe<void>::Ok();
}

}  // namespace

Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
                                                                 const TensorTuple& out_grads,
                                                                 bool retain_graph,
                                                                 bool create_graph) {
  JUST(CheckConsistentTensorsMeta(outputs));
  JUST(CheckConsistentTensorsMeta(out_grads));
  DisableCheckConsistentTensorMetaScope disable_meta_check;
  return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
}

Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
    bool retain_graph, bool create_graph) {
  JUST(CheckConsistentTensorsMeta(outputs));
  JUST(CheckConsistentTensorsMeta(inputs));
  JUST(CheckConsistentTensorsMeta(out_grads));
  DisableCheckConsistentTensorMetaScope disable_meta_check;
  return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
                                              create_graph);
}

Maybe<void> FunctionNode::AccGrad4RetainGradTensor() {
  for (const std::shared_ptr<AutogradMeta>& out : output_meta_data_) {
    if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); }
  }
  return Maybe<void>::Ok();
}

Maybe<void> FunctionNode::AccGrad4LeafTensor(bool create_graph) {
  for (auto i = 0; i < output_meta_data_.size(); i++) {
    auto& out = output_meta_data_[i];
    if (out->is_leaf() && out->requires_grad()) {
      JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false));
      // control acc_grad to do boxing conditionally
      const auto& acc_grad = out->acc_grad();
      if (GlobalGradSyncMode::is_enabled() && acc_grad->is_consistent()) {
        auto& tensor_info = output_tensor_infos_[i];
        const auto& placement = JUST(tensor_info.placement());
        const auto& nd_sbp = JUST(tensor_info.sbp());
        JUST(out->set_acc_grad(
            JUST(functional::ToConsistent(acc_grad, placement, *JUST(GetSbpList(nd_sbp)),
                                          GetNoneSbpList(), /* check_meta */ false))));
      }
    }
  }
  return Maybe<void>::Ok();
}

void FunctionNode::ReleaseOutTensorArgs() {
  for (const std::shared_ptr<AutogradMeta>& meta_data : output_meta_data_) {
    meta_data->current_grad()->Release();
  }
}

Maybe<bool> FunctionNode::Apply(bool create_graph) {
  CHECK_NOTNULL_OR_RETURN(backward_fn_)
      << "This FunctionNode with name `" << name() << "` has been released.\n"
      << "Maybe you try to backward through the node a second time. Specify retain_graph=True "
         "when calling .backward() or autograd.grad() the first time.";
  if (!IsReadyToRun(output_meta_data_)) { return false; }
  TensorTuple input_grads(input_meta_data_.size());
  TensorTuple output_grads(output_meta_data_.size());
  for (int i = 0; i < output_meta_data_.size(); ++i) {
    if (output_meta_data_.at(i)->current_grad()->Empty()) {
      output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros());
    } else {
      const auto& hooks = JUST(oneflow::VectorAt(output_meta_data_, i))->hooks();
      JUST(oneflow::VectorAt(output_grads, i)) = JUST(
          JUST(oneflow::VectorAt(output_meta_data_, i))->current_grad()->GetAccTensor(hooks));
    }
  }
  JUST(backward_fn_->body(output_grads, &input_grads, create_graph));
  for (int i = 0; i < input_meta_data_.size(); ++i) {
    if (JUST(VectorAt(input_grads, i))) {
      CHECK_NOTNULL_OR_RETURN(input_meta_data_.at(i))
          << name_
          << " calculate grad for tensor which requires_grad is False. Please submit an issue in "
             "`https://github.com/Oneflow-Inc/oneflow/issues` and we will fix it as soon as "
             "possible";
      JUST(input_meta_data_.at(i)->current_grad()->PushPartialTensor(input_grads.at(i)));
    }
  }
  return true;
}

void GraphFunctionNode::ReleaseData() {
  if (backward_fn_ && backward_fn_->status()) { backward_fn_.reset(); }
}

/*static*/ std::shared_ptr<GraphFunctionNode> GraphFunctionNode::New(
    const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
    const TensorTuple& inputs, const TensorTuple& outputs) {
  auto node = std::shared_ptr<GraphFunctionNode>(
      new GraphFunctionNode(name, backward_fn, inputs, outputs), FunctionNodeDeleter);
  return node;
}

GraphFunctionNode::GraphFunctionNode(const std::string& name,
                                     const std::shared_ptr<BackwardFunction>& backward_fn,
                                     const TensorTuple& inputs, const TensorTuple& outputs)
    : FunctionNode(name, backward_fn) {
  input_meta_data_.resize(inputs.size());
  next_functions_.reserve(inputs.size());
  for (int i = 0; i < inputs.size(); ++i) {
    if (inputs.at(i)->requires_grad()) {
      input_meta_data_.at(i) = inputs.at(i)->mut_autograd_meta();
      next_functions_.emplace_back(inputs.at(i)->mut_grad_fn_node());
    }
  }

  output_meta_data_.resize(outputs.size());
  output_tensor_infos_.reserve(outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    const auto& autograd_meta =
        NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf());
    outputs.at(i)->set_autograd_meta(autograd_meta);
    output_meta_data_.at(i) = outputs.at(i)->mut_autograd_meta();
    output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
  }

  backward_fn_ = backward_fn;
}

GraphTask::GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph)
    : retain_graph_(retain_graph), create_graph_(create_graph) {
  roots_.reserve(outputs.size());
  for (const auto& out_tensor : outputs) {
    FunctionNode* node = out_tensor->mut_grad_fn_node().get();
    roots_.emplace_back(node);
    dependencies_.insert(std::make_pair(node, 0));
  }
}

// Computes the number of dependencies for each FunctionNode
Maybe<void> GraphTask::ComputeDependencies() {
  HashSet<FunctionNode*> seen;
  std::stack<FunctionNode*> stack;
  for (FunctionNode* node : roots_) { stack.push(node); }

  while (!stack.empty()) {
    FunctionNode* node = stack.top();
    stack.pop();
    if (/*bool has_seen=*/!seen.insert(node).second) { continue; }
    for (const auto& next_grad_fn : node->next_functions()) {
      FunctionNode* next_node = next_grad_fn.get();
      dependencies_[next_node] += 1;
      if (seen.find(next_node) == seen.end()) { stack.push(next_node); }
    }
  }
  return Maybe<void>::Ok();
}

// Computes the number of dependencies for each FunctionNode and prunes useless FunctionNode
// according to input tensors
Maybe<void> GraphTask::ComputeDependenciesAndPruneNode(const TensorTuple& inputs) {
  struct NodeFrame {
    explicit NodeFrame(FunctionNode* node) : node_(node), next_function_idx_(0) {}
    FunctionNode* node_;
    size_t next_function_idx_;

    FunctionNode* GetNextFunction() {
      if (next_function_idx_ < node_->next_functions().size()) {
        next_function_idx_ += 1;
        return node_->next_functions().at(next_function_idx_ - 1).get();
      } else {
        return nullptr;
      }
    }
  };

  for (const auto& input : inputs) {
    CHECK_NOTNULL_OR_RETURN(input->mut_grad_fn_node().get());
    need_execute_.insert(input->mut_grad_fn_node().get());
  }

  HashSet<FunctionNode*> seen;
  std::stack<NodeFrame> stack;

  // Note: dfs to determine each FunctionNode should execute or not.
  for (const auto& root : roots_) { stack.push(NodeFrame(root)); }
  while (!stack.empty()) {
    NodeFrame& frame = stack.top();
    if (/*bool has_seen=*/seen.find(frame.node_) != seen.end()) {
      stack.pop();
      continue;
    }
    if (FunctionNode* node = frame.GetNextFunction()) {
      dependencies_[node] += 1;
      if (seen.find(node) == seen.end()) {
        stack.push(NodeFrame(node));
        continue;  // recurse
      }
    } else {
      bool need_execute =
          std::any_of(frame.node_->next_functions().begin(), frame.node_->next_functions().end(),
                      [&](const std::shared_ptr<FunctionNode>& fn) {
                        return need_execute_.find(fn.get()) != need_execute_.end();
                      });
      if (need_execute) { need_execute_.insert(frame.node_); }
      seen.insert(frame.node_);
      stack.pop();
    }
  }
  return Maybe<void>::Ok();
}

Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
  std::queue<FunctionNode*> queue;
  for (FunctionNode* node : roots_) {
    if (dependencies_[node] == 0) { queue.push(node); }
  }

  while (!queue.empty()) {
    FunctionNode* node = queue.front();
    queue.pop();
    if (!need_execute_.empty() && need_execute_.find(node) == need_execute_.end()) {
      node->ReleaseOutTensorArgs();
      continue;
    }
    if (/*bool not_ready_to_apply=*/!(JUST(node->Apply(create_graph_)))) { continue; }
    if (save_grad_for_leaf) { JUST(node->AccGrad4LeafTensor(create_graph_)); }
    JUST(node->AccGrad4RetainGradTensor());
    node->ReleaseOutTensorArgs();
    if (!retain_graph_) { node->ReleaseData(); }

    for (const auto& next_grad_fn : node->next_functions()) {
      FunctionNode* next_node = next_grad_fn.get();
      dependencies_[next_node] -= 1;
      if (dependencies_[next_node] == 0) { queue.push(next_node); }
    }
  }
  return Maybe<void>::Ok();
}

Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
                                                                    const TensorTuple& out_grads,
                                                                    bool retain_graph,
                                                                    bool create_graph) {
  for (int i = 0; i < outputs.size(); ++i) {
    JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
  }
  GraphTask graph_task(outputs, retain_graph, create_graph);
  JUST(graph_task.ComputeDependencies());
  JUST(graph_task.Apply(/*save_grad_for_leaf=*/true));
  return Maybe<void>::Ok();
}

Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
    const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
    bool retain_graph, bool create_graph) {
  std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
  GraphTask graph_task(outputs, retain_graph, create_graph);
  std::vector<bool> ori_retain_grad(inputs.size());
  for (int i = 0; i < inputs.size(); ++i) {
    ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
    JUST(inputs.at(i)->set_retain_grad(true));
  }
  for (int i = 0; i < outputs.size(); ++i) {
    JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
  }

  JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
  JUST(graph_task.Apply(/*save_grad_for_leaf=*/false));

  // Gets input grads and resume retain_grad
  for (int i = 0; i < inputs.size(); ++i) {
    input_current_grad->at(i) = JUST(inputs.at(i)->acc_grad());
    if (!ori_retain_grad.at(i)) {
      JUST(inputs.at(i)->set_acc_grad(nullptr));
      JUST(inputs.at(i)->set_retain_grad(false));
    }
  }
  return input_current_grad;
}

Maybe<FunctionNode> GraphAutogradEngine::AddNode(
    const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
    const TensorTuple& inputs, TensorTuple* outputs) {
  // Firstly push function_node of tensor in stack which is leaf and requires_grad
  for (const std::shared_ptr<Tensor>& in_tensor : inputs) {
    if (in_tensor->is_leaf() && in_tensor->requires_grad()) {
      if (!in_tensor->grad_fn_node()) { JUST(AddAccumulateFunctionNode(in_tensor)); }
    }
  }

  std::shared_ptr<FunctionNode> func_node =
      GraphFunctionNode::New(name, backward_fn, inputs, *outputs);
  for (const std::shared_ptr<Tensor>& out_tensor : *outputs) {
    out_tensor->set_grad_fn_node(func_node);
  }
  return func_node;
}

AutogradEngine* GetThreadLocalAutogradEngine() {
  thread_local static GraphAutogradEngine autograd_engine;
  return &autograd_engine;
}

Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor) {
  auto backward_fn = std::make_shared<BackwardFunction>();
  backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,
                          bool create_graph) -> Maybe<void> { return Maybe<void>::Ok(); };
  backward_fn->status = []() { return false; };
  tensor->set_grad_fn_node(GraphFunctionNode::New(
      "accumulate_grad", backward_fn, /*inputs=*/TensorTuple{}, /*outputs*/ TensorTuple{tensor}));
  return Maybe<void>::Ok();
}

}  // namespace one
}  // namespace oneflow
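The NOTE before FunctionNodeDeleter above explains the motivation: destroying a deep Tensor -> FunctionNode -> Tensor chain through default shared_ptr destructors recurses once per node and can overflow the stack, so every GraphFunctionNode is created with a custom deleter that walks the chain iteratively. The following standalone sketch shows the same idea on a plain linked list; ListNode, IterativeDeleter, and kDepth are illustrative names, not OneFlow code.

#include <memory>
#include <stack>
#include <utility>

// A singly linked chain; naive destruction would recurse once per node.
struct ListNode {
  std::shared_ptr<ListNode> next;
};

// Custom deleter: detach each sole-owner link onto an explicit stack before
// deleting, so no destructor ever recurses into the rest of the chain.
void IterativeDeleter(ListNode* node) {
  std::stack<std::shared_ptr<ListNode>> stack;
  if (node->next && node->next.use_count() == 1) { stack.push(std::move(node->next)); }
  delete node;
  while (!stack.empty()) {
    auto current = std::move(stack.top());
    stack.pop();
    if (current->next && current->next.use_count() == 1) { stack.push(std::move(current->next)); }
    // `current` is released here; its `next` field was already moved onto the
    // stack, so its deleter deletes a single node and returns.
  }
}

int main() {
  constexpr int kDepth = 1000000;
  std::shared_ptr<ListNode> head(new ListNode, IterativeDeleter);
  std::shared_ptr<ListNode> tail = head;
  for (int i = 0; i < kDepth; ++i) {
    tail->next = std::shared_ptr<ListNode>(new ListNode, IterativeDeleter);
    tail = tail->next;
  }
  head.reset();  // releases the whole chain with bounded stack usage
  return 0;
}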
oneflow/core/autograd/autograd_engine.h
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_

#include <list>
#include <vector>
#include <memory>
#include <functional>

#include "oneflow/core/common/util.h"
#include "oneflow/core/autograd/autograd_meta.h"

namespace oneflow {
namespace one {

class Tensor;
class TensorTuple;

using CaptureStatus = bool;

struct BackwardFunction {
  std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)> body;
  std::function<CaptureStatus()> status;
};

// Calculates one backward op
class FunctionNode {
 public:
  virtual ~FunctionNode() = default;

  Maybe<bool> Apply(bool create_graph);
  Maybe<void> AccGrad4LeafTensor(bool create_graph);
  Maybe<void> AccGrad4RetainGradTensor();
  void ReleaseOutTensorArgs();
  // Releases the eventual c++ std::function for backward if retain_graph=False to avoid calling
  // `Apply` in second time
  virtual void ReleaseData() = 0;

  const std::vector<std::shared_ptr<FunctionNode>>& next_functions() const {
    return next_functions_;
  }
  const std::string& name() const { return name_; }

 protected:
  explicit FunctionNode(const std::string& name,
                        const std::shared_ptr<BackwardFunction>& backward_fn)
      : name_(name), backward_fn_(backward_fn) {}

  const std::string name_;
  std::vector<std::shared_ptr<FunctionNode>> next_functions_;

  std::vector<std::shared_ptr<AutogradMeta>> input_meta_data_;
  std::vector<std::shared_ptr<AutogradMeta>> output_meta_data_;
  std::vector<TensorInfo> output_tensor_infos_;

  // Actual backward function builds in `AutogradInterpreter` to calculate one backward op
  std::shared_ptr<BackwardFunction> backward_fn_;
};

class AutogradEngine {
 public:
  virtual ~AutogradEngine() = default;

  Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
                                                   const TensorTuple& out_grads, bool retain_graph,
                                                   bool create_graph);
  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
                                                            const TensorTuple& inputs,
                                                            const TensorTuple& out_grads,
                                                            bool retain_graph, bool create_graph);
  virtual void ClearEngine() = 0;
  // Builds FunctionNode, binding to all `outputs_` tensors and saving in AutogradEngine
  virtual Maybe<FunctionNode> AddNode(const std::string& name,
                                      const std::shared_ptr<BackwardFunction>& backward_fn,
                                      const TensorTuple& inputs, TensorTuple* outputs) = 0;

 protected:
  AutogradEngine() = default;

 private:
  virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
                                                         const TensorTuple& out_grads,
                                                         bool retain_graph, bool create_graph) = 0;
  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
                                                                  const TensorTuple& inputs,
                                                                  const TensorTuple& out_grads,
                                                                  bool retain_graph,
                                                                  bool create_graph) = 0;
};

// Graph Autograd Node and Engine
class GraphFunctionNode final : public FunctionNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GraphFunctionNode);
  static std::shared_ptr<GraphFunctionNode> New(
      const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
      const TensorTuple& inputs, const TensorTuple& outputs);
  GraphFunctionNode() = delete;
  ~GraphFunctionNode() override = default;

  void ReleaseData() override;

 private:
  GraphFunctionNode(const std::string& name, const std::shared_ptr<BackwardFunction>& backward_fn,
                    const TensorTuple& inputs, const TensorTuple& outputs);
};

class GraphTask final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GraphTask);
  GraphTask() = delete;
  GraphTask(const TensorTuple& outputs, bool retain_graph, bool create_graph);

  Maybe<void> ComputeDependencies();
  Maybe<void> ComputeDependenciesAndPruneNode(const TensorTuple& inputs);
  Maybe<void> Apply(bool save_grad_for_leaf);

 private:
  bool retain_graph_;
  bool create_graph_;
  std::vector<FunctionNode*> roots_;
  HashMap<FunctionNode*, int> dependencies_;
  HashSet<FunctionNode*> need_execute_;
};

class GraphAutogradEngine final : public AutogradEngine {
 public:
  OF_DISALLOW_COPY_AND_MOVE(GraphAutogradEngine);
  GraphAutogradEngine() = default;
  ~GraphAutogradEngine() override = default;

  void ClearEngine() override{};
  Maybe<FunctionNode> AddNode(const std::string& name,
                              const std::shared_ptr<BackwardFunction>& backward_fn,
                              const TensorTuple& inputs, TensorTuple* outputs) override;

 private:
  Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
                                                 const TensorTuple& out_grads, bool retain_graph,
                                                 bool create_graph) override;
  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
                                                          const TensorTuple& inputs,
                                                          const TensorTuple& out_grads,
                                                          bool retain_graph,
                                                          bool create_graph) override;
};

AutogradEngine* GetThreadLocalAutogradEngine();

Maybe<void> AddAccumulateFunctionNode(const std::shared_ptr<Tensor>& tensor);

}  // namespace one
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_ENGINE_H_
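GraphTask, declared here and defined in autograd_engine.cpp above, runs backward in two phases: ComputeDependencies counts how many producers feed each node reachable from the roots, and Apply drains a queue of nodes whose count has reached zero, decrementing successors as nodes finish. That is Kahn-style topological execution. The sketch below shows the same two phases on a toy DAG; the Node struct and Run function are illustrative stand-ins, not the OneFlow types.

#include <cstdio>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  const char* name;
  std::vector<Node*> next;  // nodes that may only run after this one
};

// Phase 1 mirrors GraphTask::ComputeDependencies, phase 2 mirrors GraphTask::Apply.
void Run(const std::vector<Node*>& roots) {
  std::unordered_map<Node*, int> dependencies;
  std::unordered_set<Node*> seen;
  std::stack<Node*> stack;
  for (Node* root : roots) {
    dependencies.emplace(root, 0);
    stack.push(root);
  }
  while (!stack.empty()) {
    Node* node = stack.top();
    stack.pop();
    if (!seen.insert(node).second) { continue; }
    for (Node* next : node->next) {
      dependencies[next] += 1;
      if (seen.find(next) == seen.end()) { stack.push(next); }
    }
  }

  std::queue<Node*> ready;
  for (Node* root : roots) {
    if (dependencies[root] == 0) { ready.push(root); }
  }
  while (!ready.empty()) {
    Node* node = ready.front();
    ready.pop();
    std::printf("run %s\n", node->name);  // stands in for FunctionNode::Apply
    for (Node* next : node->next) {
      if (--dependencies[next] == 0) { ready.push(next); }
    }
  }
}

int main() {
  // loss -> mul -> add -> leaf, with a second edge loss -> add:
  // `add` must wait until both of its users have produced its grad.
  Node leaf{"leaf", {}};
  Node add{"add", {&leaf}};
  Node mul{"mul", {&add}};
  Node loss{"loss", {&mul, &add}};
  Run({&loss});  // prints: loss, mul, add, leaf
  return 0;
}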
oneflow/core/autograd/autograd_function.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/autograd/autograd_function.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
namespace
oneflow
{
namespace
one
{
/*static*/
Maybe
<
TensorTuple
>
AutogradFunctionBase
::
Apply
(
const
std
::
string
&
name
,
const
FType
&
forward_fn
,
const
FType
&
backward_fn
,
const
TensorTuple
&
inputs
)
{
std
::
shared_ptr
<
TensorTuple
>
outputs
=
std
::
make_shared
<
TensorTuple
>
();
const
auto
&
op
=
JUST
(
FunctionOpExpr
::
New
(
name
,
forward_fn
,
backward_fn
));
JUST
(
OpInterpUtil
::
Dispatch
(
*
op
,
inputs
,
outputs
.
get
(),
{}));
const
HashSet
<
Tensor
*>&
non_differentiable_tensors
=
op
->
state
()
->
NonDifferentiableTensors
();
for
(
const
auto
&
tensor
:
*
outputs
)
{
if
(
non_differentiable_tensors
.
find
(
tensor
.
get
())
!=
non_differentiable_tensors
.
end
())
{
JUST
(
tensor
->
set_requires_grad
(
false
));
}
}
return
outputs
;
}
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/autograd_function.h
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_

#include "oneflow/core/common/util.h"

namespace oneflow {
namespace one {

class TensorTuple;
class FunctionAutoGradCaptureState;
class FunctionOpExpr;

class AutogradFunctionBase {
 public:
  using FType = std::function<std::shared_ptr<TensorTuple>(
      const std::shared_ptr<FunctionAutoGradCaptureState>&, const TensorTuple&)>;

  AutogradFunctionBase() = default;
  virtual ~AutogradFunctionBase() = default;

  static Maybe<TensorTuple> Apply(const std::string& name, const FType& forward_fn,
                                  const FType& backward_fn, const TensorTuple& inputs);
};

}  // namespace one
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_FUNCTION_H_
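AutogradFunctionBase::Apply takes a user-supplied forward/backward callable pair plus a shared capture state (FunctionAutoGradCaptureState) that the forward fills and the backward later reads. The essential contract is simple: run the forward, stash the backward together with its captured state, and call it later with incoming gradients. A standalone sketch under that reading follows; the CustomFunction/Context types and the double vectors standing in for tensors are illustrative only, not the OneFlow FunctionOpExpr machinery.

#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Values = std::vector<double>;

// Captured state shared between a forward call and its backward call,
// in the spirit of FunctionAutoGradCaptureState.
struct Context {
  Values saved;
};

struct CustomFunction {
  using FType = std::function<Values(const std::shared_ptr<Context>&, const Values&)>;

  // Runs forward_fn and returns both the outputs and a thunk that will later
  // run backward_fn against the same captured context.
  static std::pair<Values, std::function<Values(const Values&)>> Apply(
      const FType& forward_fn, const FType& backward_fn, const Values& inputs) {
    auto ctx = std::make_shared<Context>();
    Values outputs = forward_fn(ctx, inputs);
    auto backward = [ctx, backward_fn](const Values& out_grads) {
      return backward_fn(ctx, out_grads);
    };
    return {outputs, backward};
  }
};

int main() {
  // y = x * x; dy/dx = 2 * x, computed from the input saved during forward.
  CustomFunction::FType forward = [](const std::shared_ptr<Context>& ctx, const Values& in) {
    ctx->saved = in;
    return Values{in[0] * in[0]};
  };
  CustomFunction::FType backward = [](const std::shared_ptr<Context>& ctx,
                                      const Values& out_grads) {
    return Values{2.0 * ctx->saved[0] * out_grads[0]};
  };
  auto result = CustomFunction::Apply(forward, backward, Values{3.0});
  assert(result.first[0] == 9.0);
  assert(result.second(Values{1.0})[0] == 6.0);
  return 0;
}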
oneflow/core/autograd/autograd_meta.cpp
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/functional/functional.h"
namespace
oneflow
{
namespace
one
{
TensorInfo
::
TensorInfo
(
const
Tensor
&
tensor
)
:
shape_
(
tensor
.
shape
()),
dtype_
(
tensor
.
dtype
())
{
if
(
TRY
(
tensor
.
device
()).
IsOk
())
{
device_
=
CHECK_JUST
(
tensor
.
device
());
}
if
(
TRY
(
tensor
.
parallel_desc
()).
IsOk
())
{
parallel_desc_
=
CHECK_JUST
(
tensor
.
parallel_desc
());
}
if
(
TRY
(
tensor
.
nd_sbp
()).
IsOk
())
{
nd_sbp_
=
CHECK_JUST
(
tensor
.
nd_sbp
());
}
}
Maybe
<
const
std
::
vector
<
Symbol
<
SbpParallel
>>&>
GetSbpTuple
(
Symbol
<
NdSbp
>
nd_sbp
)
{
static
thread_local
HashMap
<
Symbol
<
NdSbp
>
,
std
::
vector
<
Symbol
<
SbpParallel
>>>
map
;
auto
iter
=
map
.
find
(
nd_sbp
);
if
(
iter
==
map
.
end
())
{
std
::
vector
<
Symbol
<
SbpParallel
>>
sbp_tuple
;
sbp_tuple
.
reserve
(
nd_sbp
->
sbp_parallel
().
size
());
for
(
const
auto
&
sbp_parallel
:
nd_sbp
->
sbp_parallel
())
{
sbp_tuple
.
push_back
(
SymbolOf
(
sbp_parallel
));
}
iter
=
map
.
emplace
(
nd_sbp
,
sbp_tuple
).
first
;
}
return
iter
->
second
;
}
Maybe
<
Tensor
>
TensorInfo
::
zeros
()
const
{
if
(
device_
.
has_value
())
{
const
auto
&
device
=
JUST
(
device_
);
return
functional
::
Constant
(
*
shape_
.
get
(),
0
,
dtype_
,
device
);
}
else
{
const
auto
&
parallel_desc
=
JUST
(
parallel_desc_
);
const
auto
&
nd_sbp
=
JUST
(
nd_sbp_
);
const
auto
&
sbp_tuple
=
JUST
(
GetSbpTuple
(
nd_sbp
));
return
functional
::
ConsistentConstant
(
*
shape_
.
get
(),
0
,
dtype_
,
parallel_desc
,
sbp_tuple
);
}
}
AutogradMeta
::
AutogradMeta
(
bool
requires_grad
,
bool
is_leaf
)
:
is_leaf_
(
is_leaf
),
requires_grad_
(
requires_grad
),
retain_grad_
(
false
),
is_grad_acc_inplace_
(
false
),
current_grad_
(
new
TensorArg
)
{}
Maybe
<
void
>
AutogradMeta
::
set_acc_grad
(
const
std
::
shared_ptr
<
Tensor
>&
grad
)
{
if
(
const
auto
&
static_zeros_tensor
=
std
::
dynamic_pointer_cast
<
StaticZerosTensor
>
(
grad
))
{
acc_grad_
=
JUST
(
static_zeros_tensor
->
AsMirroredTensor
());
}
else
{
acc_grad_
=
grad
;
}
return
Maybe
<
void
>::
Ok
();
}
}
// namespace one
}
// namespace oneflow
oneflow/core/autograd/autograd_meta.h
0 → 100644
View file @ 21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_

#include <memory>

#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"

namespace oneflow {

class Shape;
class Device;
class ParallelDesc;
class NdSbp;

namespace one {

class Tensor;
class TensorArg;
class MirroredTensor;

class AutogradMeta final {
 public:
  AutogradMeta() = delete;
  AutogradMeta(bool requires_grad, bool is_leaf);

  // Getters
  const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
  const std::shared_ptr<TensorArg>& current_grad() const { return current_grad_; }
  bool is_grad_acc_inplace() const { return is_grad_acc_inplace_; }
  bool requires_grad() const { return requires_grad_; }
  bool is_leaf() const { return is_leaf_; }
  bool retain_grad() const { return retain_grad_; }
  using Hook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<const Tensor>&)>;
  const std::vector<Hook>& hooks() const { return hooks_; }
  const std::vector<Hook>& post_grad_accumulation_hooks() const {
    return post_grad_accumulation_hooks_;
  }

  // Setters
  Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
  std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
  void set_is_grad_acc_inplace(bool is_inplace) { is_grad_acc_inplace_ = is_inplace; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
  void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
  void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
  void add_hook(const Hook& hook) { hooks_.emplace_back(hook); }
  void add_post_grad_accumulation_hook(const Hook& hook) {
    post_grad_accumulation_hooks_.emplace_back(hook);
  }

 private:
  bool is_leaf_;
  // Only meaningful on leaf Tensors (must be false otherwise)
  bool requires_grad_;
  // Only meaningful on non_leaf Tensors (must be false otherwise)
  bool retain_grad_;
  // Control whether grad accumulation is inplace. Don't change it
  // unless you know what you are doing
  bool is_grad_acc_inplace_;

  std::shared_ptr<Tensor> acc_grad_;
  std::shared_ptr<TensorArg> current_grad_;
  std::vector<Hook> hooks_;
  std::vector<Hook> post_grad_accumulation_hooks_;
};

inline std::shared_ptr<AutogradMeta> NewAutogradMeta(bool requires_grad, bool is_leaf) {
  return std::shared_ptr<AutogradMeta>(new AutogradMeta(requires_grad, is_leaf));
}

class TensorInfo final {
 public:
  TensorInfo() = delete;
  explicit TensorInfo(const Tensor& tensor);

  Maybe<Tensor> zeros() const;
  Optional<Symbol<ParallelDesc>> placement() const { return parallel_desc_; }
  Optional<Symbol<NdSbp>> sbp() const { return nd_sbp_; }

 private:
  std::shared_ptr<const Shape> shape_;
  Symbol<DType> dtype_;
  Optional<Symbol<Device>> device_;               // for local tensor
  Optional<Symbol<ParallelDesc>> parallel_desc_;  // for consistent tensor
  Optional<Symbol<NdSbp>> nd_sbp_;                // for consistent tensor
};

}  // namespace one
}  // namespace oneflow

#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
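The hook lists in AutogradMeta follow the contract visible in CopyOrAccGrad in autograd_engine.cpp above: after the accumulated grad is updated, each post-accumulation hook is called with it, and a non-null return value replaces the stored grad. A standalone sketch of that contract, with an illustrative GradSlot type and a clipping hook as the example:

#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// A gradient slot with post-accumulation hooks, loosely mirroring AutogradMeta.
struct GradSlot {
  using Hook = std::function<std::shared_ptr<double>(const std::shared_ptr<double>&)>;

  void Accumulate(double incoming) {
    if (acc_grad) {
      acc_grad = std::make_shared<double>(*acc_grad + incoming);  // accumulate out of place
    } else {
      acc_grad = std::make_shared<double>(incoming);
    }
    for (const auto& hook : hooks) {
      auto new_grad = hook(acc_grad);
      if (new_grad) { acc_grad = new_grad; }  // a hook may replace the stored grad
    }
  }

  std::shared_ptr<double> acc_grad;
  std::vector<Hook> hooks;
};

int main() {
  GradSlot slot;
  // Example hook: clip the accumulated gradient to [-1, 1].
  slot.hooks.emplace_back([](const std::shared_ptr<double>& g) -> std::shared_ptr<double> {
    if (*g > 1.0) { return std::make_shared<double>(1.0); }
    if (*g < -1.0) { return std::make_shared<double>(-1.0); }
    return nullptr;  // no change
  });
  slot.Accumulate(0.4);
  assert(*slot.acc_grad == 0.4);
  slot.Accumulate(0.9);  // 1.3 accumulated, then clipped back to 1.0 by the hook
  assert(*slot.acc_grad == 1.0);
  return 0;
}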