OpenDAS / Oneflow · Commits

Commit 21d47d0e authored Oct 24, 2022 by yuguo

Oneflow 0.8 for DCU
Showing 20 changed files with 3282 additions and 0 deletions
oneflow/api/python/framework/framework.cpp             +55   -0
oneflow/api/python/framework/framework.h               +148  -0
oneflow/api/python/framework/id_util.cpp               +26   -0
oneflow/api/python/framework/instructions_builder.cpp  +89   -0
oneflow/api/python/framework/job_instance.cpp          +72   -0
oneflow/api/python/framework/nn_graph.cpp              +109  -0
oneflow/api/python/framework/one_embedding.cpp         +359  -0
oneflow/api/python/framework/op_builder.cpp            +51   -0
oneflow/api/python/framework/op_expr.cpp               +77   -0
oneflow/api/python/framework/parallel_conf_util.cpp    +31   -0
oneflow/api/python/framework/py_kernel_registry.cpp    +28   -0
oneflow/api/python/framework/random_generator.cpp      +73   -0
oneflow/api/python/framework/scope_util.cpp            +39   -0
oneflow/api/python/framework/session_util.cpp          +41   -0
oneflow/api/python/framework/shut_down_util.cpp        +28   -0
oneflow/api/python/framework/size.cpp                  +202  -0
oneflow/api/python/framework/size.h                    +121  -0
oneflow/api/python/framework/tensor.cpp                +711  -0
oneflow/api/python/framework/tensor.h                  +52   -0
oneflow/api/python/framework/tensor_functions.cpp      +970  -0
(Too many changes to show. To preserve performance only 556 of 556+ files are displayed.)
oneflow/api/python/framework/framework.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include <string>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/api/python/framework/framework.h"
#include "oneflow/core/framework/load_library.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("RegisterGlobalForeignCallback", &RegisterGlobalForeignCallback);
  m.def("DestroyGlobalForeignCallback", &DestroyGlobalForeignCallback);
  m.def("RegisterGlobalWatcher", &RegisterGlobalWatcher);
  m.def("LaunchJob", &LaunchJob, py::call_guard<py::gil_scoped_release>());
  m.def("GetSerializedInterUserJobInfo",
        []() -> Maybe<py::bytes> { return py::bytes(*JUST(GetSerializedInterUserJobInfo())); });
  m.def("GetSerializedJobSet",
        []() -> Maybe<py::bytes> { return py::bytes(*JUST(GetSerializedJobSet())); });
  m.def("GetSerializedStructureGraph", &GetSerializedStructureGraph /* a prototxt saved to file*/);
  m.def("GetSerializedCurrentJob",
        []() -> Maybe<py::bytes> { return py::bytes(*JUST(GetSerializedCurrentJob())); });
  m.def("GetFunctionConfigDef", &GetFunctionConfigDef);
  m.def("GetScopeConfigDef", &GetScopeConfigDef);
  m.def("GetMachine2DeviceIdListOFRecordFromParallelConf",
        &GetSerializedMachineId2DeviceIdListOFRecord);
  m.def("LoadSavedModel",
        [](const std::string& saved_model_meta_file, bool is_prototxt_file) -> Maybe<py::bytes> {
          return py::bytes(*JUST(LoadSavedModel(saved_model_meta_file, is_prototxt_file)));
        });
  m.def("EagerExecutionEnabled", EagerExecutionEnabled);
  m.def("LoadLibrary", &LoadLibrary);
}

}  // namespace oneflow
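
A minimal usage sketch of the bindings above from the Python side; the module path oneflow._oneflow_internal (the usual destination of ONEFLOW_API_PYBIND11_MODULE("")) and the library path are assumptions, not part of the diff:

# Hypothetical usage sketch; module path and .so path are assumptions.
import oneflow._oneflow_internal as internal

print(internal.EagerExecutionEnabled())        # query eager-vs-lazy execution mode
internal.LoadLibrary("/path/to/libmy_ops.so")  # placeholder path: dlopen a custom-op library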
oneflow/api/python/framework/framework.h (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_

#include <string>
#include <google/protobuf/text_format.h>

#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/inter_user_job_info.pb.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/job/foreign_watcher.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/load_library.h"
#include "oneflow/core/serving/saved_model.pb.h"

namespace oneflow {

inline Maybe<void> RegisterGlobalForeignCallback(
    const std::shared_ptr<ForeignCallback>& callback) {
  CHECK_ISNULL_OR_RETURN(Singleton<std::shared_ptr<ForeignCallback>>::Get())
      << "foreign callback registered";
  // Singleton<T>::SetAllocated is preferred since Singleton<T>::New will output logs but
  // glog is not constructed yet.
  Singleton<std::shared_ptr<ForeignCallback>>::SetAllocated(
      new std::shared_ptr<ForeignCallback>(callback));
  return Maybe<void>::Ok();
}

inline Maybe<void> DestroyGlobalForeignCallback() {
  if (Singleton<std::shared_ptr<ForeignCallback>>::Get()) {
    Singleton<std::shared_ptr<ForeignCallback>>::Delete();
  }
  return Maybe<void>::Ok();
}

inline Maybe<void> RegisterGlobalWatcher(const std::shared_ptr<ForeignWatcher>& watcher) {
  CHECK_ISNULL_OR_RETURN(Singleton<std::shared_ptr<ForeignWatcher>>::Get())
      << "foreign watcher registered";
  // Singleton<T>::SetAllocated is preferred since Singleton<T>::New will output logs but
  // glog is not constructed yet.
  Singleton<std::shared_ptr<ForeignWatcher>>::SetAllocated(
      new std::shared_ptr<ForeignWatcher>(watcher));
  return Maybe<void>::Ok();
}

inline Maybe<void> LaunchJob(const std::shared_ptr<oneflow::JobInstance>& cb) {
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  CHECK_NOTNULL_OR_RETURN(Singleton<Oneflow>::Get());
  const auto& job_name = cb->job_name();
  auto* buffer_mgr = Singleton<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
  int64_t job_id = Singleton<JobName2JobId>::Get()->at(job_name);
  if (IsPullJob(job_name, *Singleton<InterUserJobInfo>::Get())) {
    buffer_mgr->Get(GetForeignOutputBufferName(job_name))->Push(cb);
  }
  if (IsPushJob(job_name, *Singleton<InterUserJobInfo>::Get())) {
    buffer_mgr->Get(GetForeignInputBufferName(job_name))->Push(cb);
  }
  buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(cb);
  Singleton<BufferMgr<int64_t>>::Get()->Get(kBufferNameGlobalWaitJobId)->Push(job_id);
  return Maybe<void>::Ok();
}

inline Maybe<std::string> GetSerializedStructureGraph() {
  const auto* job_ctx_mgr = Singleton<LazyJobBuildAndInferCtxMgr>::Get();
  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);
  return job_ctx_mgr->structure_graph();
}

inline Maybe<std::string> GetSerializedInterUserJobInfo() {
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  CHECK_NOTNULL_OR_RETURN(Singleton<Oneflow>::Get());
  CHECK_NOTNULL_OR_RETURN(Singleton<InterUserJobInfo>::Get());
  return Singleton<InterUserJobInfo>::Get()->SerializeAsString();
}

inline Maybe<const JobSet&> GetJobSet() {
  auto* job_ctx_mgr = JUST(GlobalJobBuildAndInferCtxMgr());
  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);
  return job_ctx_mgr->job_set();
}

inline Maybe<std::string> GetSerializedJobSet() { return JUST(GetJobSet()).SerializeAsString(); }

inline Maybe<std::string> GetSerializedCurrentJob() {
  auto* job_ctx_mgr = Singleton<LazyJobBuildAndInferCtxMgr>::Get();
  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);
  auto* job_ctx =
      JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName())));
  CHECK_NOTNULL_OR_RETURN(job_ctx);
  return job_ctx->job().SerializeAsString();
}

inline Maybe<std::string> GetFunctionConfigDef() {
  std::string ret;
  google::protobuf::TextFormat::PrintToString(GlobalFunctionConfigDef(), &ret);
  return ret;
}

inline Maybe<std::string> GetScopeConfigDef() {
  std::string ret;
  google::protobuf::TextFormat::PrintToString(GlobalScopeConfigDef(), &ret);
  return ret;
}

inline Maybe<std::string> GetSerializedMachineId2DeviceIdListOFRecord(
    const std::string& parallel_conf_str) {
  ParallelConf parallel_conf;
  CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf))
      << "parallel conf parse failed";
  return PbMessage2TxtString(*JUST(ParseMachineAndDeviceIdList(parallel_conf)));
}

inline Maybe<std::string> LoadSavedModel(const std::string& saved_model_meta_file,
                                         bool is_prototxt_file) {
  SavedModel saved_model_proto;
  if (is_prototxt_file) {
    CHECK_OR_RETURN(TryParseProtoFromTextFile(saved_model_meta_file, &saved_model_proto));
  } else {
    CHECK_OR_RETURN(TryParseProtoFromPbFile(saved_model_meta_file, &saved_model_proto));
  }
  return saved_model_proto.SerializeAsString();
}

inline Maybe<void> LoadLibraryNow(const std::string& lib_path) { return LoadLibrary(lib_path); }

}  // namespace oneflow

#endif  // ONEFLOW_API_PYTHON_FRAMEWORK_FRAMEWORK_H_
oneflow/api/python/framework/id_util.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/id_util.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("UniqueStr", &UniqueStr); }

}  // namespace oneflow
oneflow/api/python/framework/instructions_builder.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <functional>

#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> DeprecatedPhysicalRun(const std::function<void(InstructionsBuilder*)>& Build) {
  return PhysicalRun([&](InstructionsBuilder* instruction_builder) -> Maybe<void> {
    Build(instruction_builder);
    return Maybe<void>::Ok();
  });
}

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
  py::class_<InstructionsBuilder, std::shared_ptr<InstructionsBuilder>>(m, "InstructionsBuilder")
      .def("BuildInitialScope",
           [](const std::shared_ptr<InstructionsBuilder>& builder, int64_t session_id,
              const std::string& job_conf_str, const std::string& device_tag,
              const std::vector<std::string>& machine_device_ids,
              const std::shared_ptr<Shape>& hierarchy, bool is_mirrored) -> Maybe<Scope> {
             JobConfigProto job_conf;
             CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf))
                 << Error::RuntimeError() << "job conf parse failed";
             return builder->BuildInitialScope(session_id, job_conf, device_tag,
                                               machine_device_ids, hierarchy, is_mirrored);
           },
           py::arg("session_id").none(false), py::arg("job_conf_str").none(false),
           py::arg("device_tag").none(false), py::arg("machine_device_ids").none(false),
           py::arg("hierarchy").none(true), py::arg("is_mirrored").none(false))
      .def("BuildInitialScopeWithPlacement",
           [](const std::shared_ptr<InstructionsBuilder>& builder, int64_t session_id,
              const std::string& job_conf_str, Symbol<ParallelDesc> placement,
              bool is_mirrored) -> Maybe<Scope> {
             JobConfigProto job_conf;
             CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf))
                 << Error::RuntimeError() << "job conf parse failed";
             return builder->BuildInitialScopeWithPlacement(session_id, job_conf, placement,
                                                            is_mirrored);
           },
           py::arg("session_id").none(false), py::arg("job_conf_str").none(false),
           py::arg("placement").none(false), py::arg("is_mirrored").none(false))
      .def("BuildScopeWithNewParallelDesc", &InstructionsBuilder::BuildScopeWithNewParallelDesc,
           py::arg("scope").none(false), py::arg("device_tag").none(false),
           py::arg("machine_device_ids").none(false), py::arg("hierarchy").none(true))
      .def("BuildScopeWithNewParallelConf",
           [](const std::shared_ptr<InstructionsBuilder>& builder,
              const std::shared_ptr<Scope>& scope,
              const std::string& parallel_conf_str) -> Maybe<Scope> {
             ParallelConf parallel_conf;
             CHECK_OR_RETURN(TxtString2PbMessage(parallel_conf_str, &parallel_conf))
                 << Error::RuntimeError() << "parallel conf parse failed";
             return builder->BuildScopeWithNewParallelConf(scope, parallel_conf);
           })
      .def("BuildScopeWithNewIsMirrored", &InstructionsBuilder::BuildScopeWithNewIsMirrored)
      .def("BuildScopeWithNewScopeName", &InstructionsBuilder::BuildScopeWithNewScopeName)
      .def("BuildScopeByProtoStrSetter", &InstructionsBuilder::BuildScopeByProtoStrSetter);

  m.def("PhysicalRun", &DeprecatedPhysicalRun, py::call_guard<py::gil_scoped_release>());
}

}  // namespace oneflow
oneflow/api/python/framework/job_instance.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include <string>
#include <memory>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_instance.h"

namespace py = pybind11;

namespace oneflow {

class PyJobInstance : public JobInstance {
 public:
  // Inherit the constructors
  using JobInstance::JobInstance;

  // Trampoline (need one for each virtual function)
  std::string job_name() const override {
    PYBIND11_OVERRIDE(std::string, /* Return type */
                      JobInstance, /* Parent class */
                      job_name,    /* Name of function in C++ (must match Python name) */
    );
  }

  std::string sole_input_op_name_in_user_job() const override {
    PYBIND11_OVERRIDE(std::string, JobInstance, sole_input_op_name_in_user_job, );
  }

  std::string sole_output_op_name_in_user_job() const override {
    PYBIND11_OVERRIDE(std::string, JobInstance, sole_output_op_name_in_user_job, );
  }

  void PushBlob(uint64_t ofblob_ptr) const override {
    PYBIND11_OVERRIDE(void, JobInstance, PushBlob, ofblob_ptr);
  }

  void PullBlob(uint64_t ofblob_ptr) const override {
    PYBIND11_OVERRIDE(void, JobInstance, PullBlob, ofblob_ptr);
  }

  void Finish() const override { PYBIND11_OVERRIDE(void, JobInstance, Finish, ); }
};

}  // namespace oneflow

ONEFLOW_API_PYBIND11_MODULE("", m) {
  using namespace oneflow;

  py::class_<JobInstance, PyJobInstance, std::shared_ptr<JobInstance>>(m, "JobInstance")
      .def(py::init<>())
      .def("job_name", &JobInstance::job_name)
      .def("sole_input_op_name_in_user_job", &JobInstance::sole_input_op_name_in_user_job)
      .def("sole_output_op_name_in_user_job", &JobInstance::sole_output_op_name_in_user_job)
      .def("PushBlob", &JobInstance::PushBlob)
      .def("PullBlob", &JobInstance::PullBlob)
      .def("Finish", &JobInstance::Finish);
}
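
PyJobInstance is the standard pybind11 trampoline pattern: each PYBIND11_OVERRIDE lets a Python subclass supply the body of the corresponding C++ virtual. A minimal sketch of such a subclass, assuming the class is exposed as oneflow._oneflow_internal.JobInstance (module path is an assumption):

# Hypothetical Python-side JobInstance; the trampolines above dispatch the
# C++ virtual calls back into these Python overrides.
import oneflow._oneflow_internal as internal

class MyJobInstance(internal.JobInstance):
    def __init__(self, name):
        super().__init__()
        self._name = name

    def job_name(self):   # overrides JobInstance::job_name via PYBIND11_OVERRIDE
        return self._name

    def Finish(self):     # overrides JobInstance::Finish
        print(f"job {self._name} finished")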
oneflow/api/python/framework/nn_graph.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>
#include <string>

#include "oneflow/api/python/job_build/job_build_and_infer.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/job/runtime.h"
#include "oneflow/core/register/blob.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job_ir.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<py::object> APINNGraphAdditionalVarNames(const std::shared_ptr<NNGraph>& graph) {
  const auto names = *JUST(graph->GetAdditionalVarOpNames());
  py::list name_list = py::cast(names);
  return py::cast<py::object>(name_list);
}

Maybe<py::object> APINNGraphAdditionalVarTensors(const std::shared_ptr<NNGraph>& graph) {
  const auto tensors = *JUST(graph->GetAdditionalVarOpTensors());
  py::list tensor_list = py::cast(tensors);
  return py::cast<py::object>(tensor_list);
}

Maybe<py::bytes> APINNGraphGetCurrentSerializedJob(const std::shared_ptr<NNGraph>& graph) {
  const auto job = graph->job();
  return py::bytes(job.SerializeAsString());
}

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
  using namespace oneflow;

  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "CNNGraph")
      .def(py::init([](const std::string& name, const std::string& serialized_job, int64_t job_id,
                       const std::shared_ptr<MultiClientSessionContext>& session_ctx) {
        Job job;
        if (!job.ParseFromString(serialized_job)) {
          PyErr_SetString(PyExc_TypeError, "the second argument is not a valid job");
        }
        return std::make_shared<NNGraph>(name, job, job_id, session_ctx);
      }))
      .def_property_readonly("name", &NNGraph::job_name)
      .def_property(
          "job",
          /*getter*/
          [](const NNGraph& nn_graph) { return py::bytes(nn_graph.job().SerializeAsString()); },
          /*setter*/
          [](NNGraph& nn_graph, const std::string& serialized_job) {
            Job job;
            if (!job.ParseFromString(serialized_job)) {
              PyErr_SetString(PyExc_TypeError, "the value is not a valid job");
            }
            nn_graph.restore_job(job);
          })
      .def_property("job_id", &NNGraph::job_id,
                    [](NNGraph& nn_graph, int64_t job_id) { nn_graph.restore_job_id(job_id); })
      .def("register_input_op_names_and_tensors", &NNGraph::RegisterInputOpNamesAndTensors)
      .def("register_output_op_names_and_tensors", &NNGraph::RegisterOutputOpNamesAndTensors)
      .def("register_variable_op_names_and_tensors", &NNGraph::RegisterVariableOpNamesAndTensors)
      .def("register_additional_variable_names_and_tensors",
           &NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded)
      .def_property_readonly("additional_var_names", &APINNGraphAdditionalVarNames)
      .def_property_readonly("additional_var_tensors", &APINNGraphAdditionalVarTensors)
      // [sic] "complie" is the registered method name; renaming it would break callers.
      .def("complie_and_init_runtime", &NNGraph::CompileAndInitRuntime)
      .def("get_current_job_str", &APINNGraphGetCurrentSerializedJob);

  m.def("RunLazyNNGraph", &RunLazyNNGraph);
  m.def("SoftSyncNNGraphBuffers", &SoftSyncNNGraphBuffers);
  m.def("AddTensorAsGraphLoss", &AddTensorAsGraphLoss);
  m.def("ConvertJobToTosaIR", [](const std::string& serialized_job) -> Maybe<std::string> {
    Job job;
    CHECK_OR_RETURN(TxtString2PbMessage(serialized_job, &job))
        << "serialized job conversion failed.";
    return ConvertJobToTosaIR(&job);
  });
  m.def("SaveJobToIR",
        [](const std::string& serialized_job, const std::string& path) -> Maybe<void> {
          Job job;
          CHECK_OR_RETURN(TxtString2PbMessage(serialized_job, &job))
              << "serialized job conversion failed.";
          return SaveJobToIR(&job, path);
        });
  m.def("LoadSerializedJobFromIR", [](const std::string& path) -> Maybe<py::bytes> {
    Job job;
    JUST(LoadJobFromIR(&job, path));
    return py::bytes(job.SerializeAsString());
  });
}

}  // namespace oneflow
oneflow/api/python/framework/one_embedding.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/embedding/embedding_manager.h"
#include "oneflow/core/embedding/persistent_table.h"
#include "oneflow/core/embedding/hash_functions.cuh"
#include "oneflow/core/framework/dtype.h"

namespace py = pybind11;

namespace oneflow {

class OneEmbeddingHandler final {
 public:
  OneEmbeddingHandler(const std::string& key_value_store_option_string, int64_t local_rank_id,
                      int64_t rank_id, int64_t world_size)
      : local_rank_id_(local_rank_id), rank_id_(rank_id), world_size_(world_size) {
    embedding::KeyValueStoreOptions key_value_store_options(key_value_store_option_string);
    embedding_name_ = key_value_store_options.Name();
    CreateKeyValueStore(key_value_store_options);
  }

  void LoadSnapshot(const std::string& snapshot_name) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
    Singleton<embedding::EmbeddingManager>::Get()->LoadSnapshot(embedding_name_, local_rank_id_,
                                                                rank_id_, snapshot_name);
#else
    UNIMPLEMENTED() << "Only Support with CUDA";
#endif
  }

  void SaveSnapshot(const std::string& snapshot_name) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
    Singleton<embedding::EmbeddingManager>::Get()->SaveSnapshot(embedding_name_, local_rank_id_,
                                                                rank_id_, snapshot_name);
#else
    UNIMPLEMENTED() << "Only Support with CUDA";
#endif
  }

 private:
  void CreateKeyValueStore(const embedding::KeyValueStoreOptions& key_value_store_options) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
    Singleton<embedding::EmbeddingManager>::Get()->CreateKeyValueStore(
        key_value_store_options, local_rank_id_, rank_id_, world_size_);
#else
    UNIMPLEMENTED() << "Only Support with CUDA";
#endif
  }

  std::string embedding_name_;
  int64_t local_rank_id_;
  int64_t rank_id_;
  int64_t world_size_;
};

namespace embedding {

class PersistentTableWriter {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriter);
  PersistentTableWriter() = default;
  virtual ~PersistentTableWriter() = default;

  virtual void Write(const py::array& keys, const py::array& values) = 0;
  virtual void Close() = 0;
};

template<typename Key, typename Value>
class PersistentTableWriterImpl : public PersistentTableWriter {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PersistentTableWriterImpl);
  PersistentTableWriterImpl(const std::vector<std::string>& paths,
                            const std::string& snapshot_name, uint32_t storage_dim,
                            uint64_t target_chunk_size_mb, uint16_t physical_block_size)
      : closed_(false), snapshot_name_(snapshot_name), storage_dim_(storage_dim) {
    tables_.resize(paths.size());
    for (size_t i = 0; i < paths.size(); ++i) {
      PersistentTableOptions options;
      options.path = paths[i];
      options.key_size = sizeof(Key);
      options.value_size = storage_dim * sizeof(Value);
      options.target_chunk_size_mb = target_chunk_size_mb;
      options.physical_block_size = physical_block_size;
      tables_[i] = NewPersistentTable(options);
    }
  }
  ~PersistentTableWriterImpl() override { CloseImpl(); }

  void Write(const py::array& keys, const py::array& values) override {
    pybind11::dtype::of<int32_t>().equal(pybind11::dtype::of<int64_t>());
    CHECK(!closed_) << "Write on closed table";
    CHECK_EQ(keys.ndim(), 1);
    CHECK_EQ(values.ndim(), 2);
    CHECK_EQ(keys.shape(0), values.shape(0));
    CHECK_EQ(values.shape(1), storage_dim_);
    CHECK(keys.dtype().equal(py::dtype::of<Key>()));
    CHECK(values.dtype().equal(py::dtype::of<Value>()));
    const size_t n = keys.size();
    std::vector<std::vector<Key>> keys_buffers(tables_.size());
    std::vector<std::vector<char>> values_buffers(tables_.size());
    for (size_t i = 0; i < n; ++i) {
      const Key key = *(reinterpret_cast<const Key*>(keys.template data(i)));
      const uint32_t shard = ShardingHash()(key) % tables_.size();
      keys_buffers[shard].push_back(key);
      const size_t values_offset = values_buffers[shard].size();
      values_buffers[shard].resize(values_offset + storage_dim_ * sizeof(Value));
      for (size_t j = 0; j < values.shape(1); ++j) {
        std::memcpy(values_buffers[shard].data() + values_offset + j * values.itemsize(),
                    values.template data(i, j), values.itemsize());
      }
    }
    for (size_t shard = 0; shard < tables_.size(); ++shard) {
      tables_[shard]->Put(keys_buffers[shard].size(), keys_buffers[shard].data(),
                          values_buffers[shard].data());
    }
  }

  void Close() override { CloseImpl(); }

 private:
  void CloseImpl() {
    if (!closed_) {
      for (auto& table : tables_) {
        table->SaveSnapshot(snapshot_name_);
        table.reset();
      }
    }
    closed_ = true;
  }

  bool closed_;
  std::string snapshot_name_;
  std::vector<std::unique_ptr<PersistentTable>> tables_;
  uint32_t storage_dim_;
};

template<typename Key>
std::shared_ptr<PersistentTableWriter> NewPersistentTableWriter(
    const std::vector<std::string>& paths, const std::string& snapshot_name,
    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,
    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {
  if (value_type->data_type() == DataType::kFloat) {
    return std::shared_ptr<PersistentTableWriter>(new PersistentTableWriterImpl<Key, float>(
        paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size));
  } else {
    UNIMPLEMENTED();
  }
}

std::shared_ptr<PersistentTableWriter> NewPersistentTableWriter(
    const std::vector<std::string>& paths, const std::string& snapshot_name,
    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,
    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {
  if (key_type->data_type() == DataType::kInt32) {
    return NewPersistentTableWriter<int32_t>(paths, snapshot_name, key_type, value_type,
                                             storage_dim, target_chunk_size_mb,
                                             physical_block_size);
  } else if (key_type->data_type() == DataType::kUInt32) {
    return NewPersistentTableWriter<uint32_t>(paths, snapshot_name, key_type, value_type,
                                              storage_dim, target_chunk_size_mb,
                                              physical_block_size);
  } else if (key_type->data_type() == DataType::kInt64) {
    return NewPersistentTableWriter<int64_t>(paths, snapshot_name, key_type, value_type,
                                             storage_dim, target_chunk_size_mb,
                                             physical_block_size);
  } else if (key_type->data_type() == DataType::kUInt64) {
    return NewPersistentTableWriter<uint64_t>(paths, snapshot_name, key_type, value_type,
                                              storage_dim, target_chunk_size_mb,
                                              physical_block_size);
  } else {
    UNIMPLEMENTED();
    return std::shared_ptr<embedding::PersistentTableWriter>(nullptr);
  }
}

class PersistentTableReader {
 public:
  OF_DISALLOW_COPY_AND_MOVE(PersistentTableReader);
  PersistentTableReader() = default;
  virtual ~PersistentTableReader() = default;

  virtual std::tuple<py::object, py::object> Next() = 0;
  virtual void Close() = 0;
};

template<typename Key, typename Value>
class PersistentTableReaderImpl : public PersistentTableReader {
 public:
  constexpr static uint32_t kBatchSize = 65536;
  OF_DISALLOW_COPY_AND_MOVE(PersistentTableReaderImpl);
  PersistentTableReaderImpl(const std::vector<std::string>& paths,
                            const std::string& snapshot_name, uint32_t storage_dim,
                            uint64_t target_chunk_size_mb, uint16_t physical_block_size)
      : closed_(false),
        snapshot_name_(snapshot_name),
        storage_dim_(storage_dim),
        current_table_(0) {
    tables_.resize(paths.size());
    iterators_.resize(paths.size());
    for (size_t i = 0; i < paths.size(); ++i) {
      PersistentTableOptions options;
      options.path = paths[i];
      options.key_size = sizeof(Key);
      options.value_size = storage_dim * sizeof(Value);
      options.target_chunk_size_mb = target_chunk_size_mb;
      options.physical_block_size = physical_block_size;
      tables_[i] = NewPersistentTable(options);
      iterators_[i] =
          std::unique_ptr<PersistentTable::Iterator>(tables_[i]->ReadSnapshot(snapshot_name));
    }
    keys_buffer_.resize(kBatchSize);
    values_buffer_.resize(kBatchSize * storage_dim_);
  }
  ~PersistentTableReaderImpl() override { CloseImpl(); }

  std::tuple<py::object, py::object> Next() override {
    while (current_table_ < tables_.size()) {
      uint32_t n_result = 0;
      iterators_[current_table_]->Next(kBatchSize, &n_result, keys_buffer_.data(),
                                       values_buffer_.data());
      if (n_result != 0) {
        py::array_t<Key> keys_arr(py::array::ShapeContainer({n_result}));
        py::array_t<Value> values_arr(py::array::ShapeContainer({n_result, storage_dim_}));
        std::memcpy(keys_arr.mutable_data(), keys_buffer_.data(), n_result * sizeof(Key));
        std::memcpy(values_arr.mutable_data(), values_buffer_.data(),
                    n_result * storage_dim_ * sizeof(Value));
        return std::make_tuple(keys_arr, values_arr);
      } else {
        current_table_ += 1;
        continue;
      }
    }
    throw py::stop_iteration();
  }

  void Close() override { CloseImpl(); }

 private:
  void CloseImpl() {
    if (!closed_) {
      for (auto& table : tables_) { table.reset(); }
    }
    closed_ = true;
  }

  bool closed_;
  std::string snapshot_name_;
  std::vector<std::unique_ptr<PersistentTable>> tables_;
  std::vector<std::unique_ptr<PersistentTable::Iterator>> iterators_;
  uint32_t storage_dim_;
  size_t current_table_;
  std::vector<Key> keys_buffer_;
  std::vector<Value> values_buffer_;
};

template<typename Key>
std::shared_ptr<PersistentTableReader> NewPersistentTableReader(
    const std::vector<std::string>& paths, const std::string& snapshot_name,
    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,
    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {
  if (value_type->data_type() == DataType::kFloat) {
    return std::shared_ptr<PersistentTableReader>(new PersistentTableReaderImpl<Key, float>(
        paths, snapshot_name, storage_dim, target_chunk_size_mb, physical_block_size));
  } else {
    UNIMPLEMENTED();
  }
}

std::shared_ptr<PersistentTableReader> NewPersistentTableReader(
    const std::vector<std::string>& paths, const std::string& snapshot_name,
    const Symbol<DType>& key_type, const Symbol<DType>& value_type, uint32_t storage_dim,
    uint64_t target_chunk_size_mb, uint16_t physical_block_size) {
  if (key_type->data_type() == DataType::kInt32) {
    return NewPersistentTableReader<int32_t>(paths, snapshot_name, key_type, value_type,
                                             storage_dim, target_chunk_size_mb,
                                             physical_block_size);
  } else if (key_type->data_type() == DataType::kUInt32) {
    return NewPersistentTableReader<uint32_t>(paths, snapshot_name, key_type, value_type,
                                              storage_dim, target_chunk_size_mb,
                                              physical_block_size);
  } else if (key_type->data_type() == DataType::kInt64) {
    return NewPersistentTableReader<int64_t>(paths, snapshot_name, key_type, value_type,
                                             storage_dim, target_chunk_size_mb,
                                             physical_block_size);
  } else if (key_type->data_type() == DataType::kUInt64) {
    return NewPersistentTableReader<uint64_t>(paths, snapshot_name, key_type, value_type,
                                              storage_dim, target_chunk_size_mb,
                                              physical_block_size);
  } else {
    UNIMPLEMENTED();
    return std::shared_ptr<embedding::PersistentTableReader>(nullptr);
  }
}

}  // namespace embedding

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<OneEmbeddingHandler, std::shared_ptr<OneEmbeddingHandler>>(m, "OneEmbeddingHandler")
      .def(py::init([](const std::string& key_value_store_option_str, const int64_t local_rank_id,
                       const int64_t rank_id, const int64_t world_size) {
        return std::make_shared<OneEmbeddingHandler>(key_value_store_option_str, local_rank_id,
                                                     rank_id, world_size);
      }))
      .def("SaveSnapshot", &OneEmbeddingHandler::SaveSnapshot)
      .def("LoadSnapshot", &OneEmbeddingHandler::LoadSnapshot);

  py::class_<embedding::PersistentTableWriter, std::shared_ptr<embedding::PersistentTableWriter>>(
      m, "PersistentTableWriter")
      .def(py::init([](const std::vector<std::string>& paths, const std::string& snapshot_name,
                       const Symbol<DType>& key_type, const Symbol<DType>& value_type,
                       uint32_t storage_dim, uint64_t target_chunk_size_mb,
                       uint16_t physical_block_size) {
        return embedding::NewPersistentTableWriter(paths, snapshot_name, key_type, value_type,
                                                   storage_dim, target_chunk_size_mb,
                                                   physical_block_size);
      }))
      .def("__enter__", [](embedding::PersistentTableWriter* writer) { return writer; })
      .def("__exit__",
           [](embedding::PersistentTableWriter* writer, const py::object& exc_type,
              const py::object& exc_val, const py::object& exc_tb) { writer->Close(); })
      .def("write", &embedding::PersistentTableWriter::Write)
      .def("close", &embedding::PersistentTableWriter::Close);

  py::class_<embedding::PersistentTableReader, std::shared_ptr<embedding::PersistentTableReader>>(
      m, "PersistentTableReader")
      .def(py::init([](const std::vector<std::string>& paths, const std::string& snapshot_name,
                       const Symbol<DType>& key_type, const Symbol<DType>& value_type,
                       uint32_t storage_dim, uint64_t target_chunk_size_mb,
                       uint16_t physical_block_size) {
        return embedding::NewPersistentTableReader(paths, snapshot_name, key_type, value_type,
                                                   storage_dim, target_chunk_size_mb,
                                                   physical_block_size);
      }))
      .def("__next__", &embedding::PersistentTableReader::Next)
      .def("__iter__", [](embedding::PersistentTableReader* reader) { return reader; })
      .def("__enter__", [](embedding::PersistentTableReader* reader) { return reader; })
      .def("__exit__",
           [](embedding::PersistentTableReader* reader, const py::object& exc_type,
              const py::object& exc_val, const py::object& exc_tb) { reader->Close(); })
      .def("close", &embedding::PersistentTableReader::Close);
}

}  // namespace oneflow
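
PersistentTableWriter and PersistentTableReader implement Python's context-manager and iteration protocols through the __enter__/__exit__/__iter__/__next__ defs above, so a table can be written and read back as plain numpy batches. A hedged round-trip sketch; the oneflow._oneflow_internal module path, the use of oneflow dtypes for the Symbol<DType> parameters, and the concrete option values are all assumptions:

# Hypothetical round-trip sketch for the bindings above; module path,
# dtype arguments, and option values are assumptions, not confirmed by the diff.
import numpy as np
import oneflow as flow
import oneflow._oneflow_internal as internal

paths = ["/tmp/emb_shard0"]  # one directory per shard; keys are sharded by ShardingHash
with internal.PersistentTableWriter(paths, "snap0", flow.int64, flow.float32,
                                    128, 64, 512) as w:
    keys = np.arange(10, dtype=np.int64)            # shape (N,), must match key_type
    values = np.zeros((10, 128), dtype=np.float32)  # shape (N, storage_dim)
    w.write(keys, values)                           # __exit__ saves the snapshot via Close()

with internal.PersistentTableReader(paths, "snap0", flow.int64, flow.float32,
                                    128, 64, 512) as r:
    for keys, values in r:  # yields up to kBatchSize (65536) pairs per batch
        print(keys.shape, values.shape)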
oneflow/api/python/framework/op_builder.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_builder.h"

namespace py = pybind11;

namespace oneflow {
namespace one {

ONEFLOW_API_PYBIND11_MODULE("one", m) {
  py::class_<one::OpBuilder, std::shared_ptr<one::OpBuilder>>(m, "OpBuilder")
      .def(py::init<const std::string&>())
      .def(py::init<const std::string&, const std::string&>())
      .def("input", &OpBuilder::MaybeInput)
      .def("output", &OpBuilder::MaybeOutput)
      .def("attr",
           [](const std::shared_ptr<one::OpBuilder>& x, const std::string& attr_name,
              const std::string& attr_val_str) -> Maybe<OpBuilder&> {
             AttrValue attr_val;
             if (!TxtString2PbMessage(attr_val_str, &attr_val)) {
               THROW(RuntimeError) << "attr val parse failed.\n" << attr_val_str;
             }
             return x->MaybeAttr(attr_name, attr_val);
           })
      .def("build", &OpBuilder::Build);
}

}  // namespace one
}  // namespace oneflow
oneflow/api/python/framework/op_expr.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"

namespace py = pybind11;

namespace oneflow {

namespace {

template<typename OpT, typename ConfT,
         typename std::enable_if<std::is_base_of<one::BuiltinOpExpr, OpT>::value>::type* = nullptr>
py::class_<OpT, one::BuiltinOpExpr, std::shared_ptr<OpT>> PybindExportOpExpr(
    py::module& m, const char* op_type_name) {
  return py::class_<OpT, one::BuiltinOpExpr, std::shared_ptr<OpT>>(m, op_type_name)
      .def(py::init([](const std::string& op_name, const std::string& op_conf_str,
                       const std::vector<std::string>& indexed_ibns,
                       const std::vector<std::string>& indexed_obns) {
        ConfT proto_op_conf;
        if (!TxtString2PbMessage(op_conf_str, &proto_op_conf)) {
          THROW(RuntimeError) << "op conf parse failed.\n" << op_conf_str;
        }
        return OpT::New(op_name, std::move(proto_op_conf), indexed_ibns, indexed_obns)
            .GetPtrOrThrow();
      }));
}

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("one", m) {
  py::class_<one::OpExpr, std::shared_ptr<one::OpExpr>>(m, "OpExpr")
      .def_property_readonly("op_type_name", &one::OpExpr::op_type_name)
      .def_property_readonly("input_size", &one::OpExpr::input_size)
      .def_property_readonly("output_size", &one::OpExpr::output_size);

  py::class_<one::BuiltinOpExpr, one::OpExpr, std::shared_ptr<one::BuiltinOpExpr>>(
      m, "BuiltinOpExpr")
      .def_property_readonly("name", &one::BuiltinOpExpr::op_name)
      .def_property_readonly("indexed_ibns", &one::BuiltinOpExpr::indexed_ibns)
      .def_property_readonly("indexed_obns", &one::BuiltinOpExpr::indexed_obns);

  auto py_user_op_class = PybindExportOpExpr<one::UserOpExpr, UserOpConf>(m, "UserOpExpr");
  py_user_op_class.def_property_readonly(
      "op_type_name", [](const one::UserOpExpr& op) { return op.proto().op_type_name(); });

  PybindExportOpExpr<one::VariableOpExpr, VariableOpConf>(m, "VariableOpExpr");
  // NOTE(chengcheng): export for Lazy nn.Graph Feed/Fetch EagerTensor to/from LazyTensor.
  PybindExportOpExpr<one::FeedInputOpExpr, FeedInputOpConf>(m, "FeedInputOpExpr");
  PybindExportOpExpr<one::FeedVariableOpExpr, FeedVariableOpConf>(m, "FeedVariableOpExpr");
  PybindExportOpExpr<one::FetchOutputOpExpr, FetchOutputOpConf>(m, "FetchOutputOpExpr");
  PybindExportOpExpr<one::ImageDecoderRandomCropResizeOpExpr, ImageDecoderRandomCropResizeOpConf>(
      m, "ImageDecoderRandomCropResizeOpExpr");
}

}  // namespace oneflow
oneflow/api/python/framework/parallel_conf_util.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/parallel_conf_util.h"

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("GetDeviceTagAndMachineDeviceIdsAndHierarchy",
        &GetDeviceTagAndMachineDeviceIdsAndHierarchy);
  m.def("MakeParallelConf", &MakeParallelConf);
}

}  // namespace oneflow
oneflow/api/python/framework/py_kernel_registry.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include <string>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/extension/python/py_kernel_registry.h"

namespace py = pybind11;

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("RegisterPyKernelCaller", &::oneflow::pyext::RegisterPyKernelCaller);
  m.def("RegisterPyKernels",
        [](py::object py_kernels) { ::oneflow::pyext::RegisterPyKernels(py_kernels.ptr()); });
}
oneflow/api/python/framework/random_generator.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/tensor.h"

namespace py = pybind11;

namespace oneflow {

Maybe<one::Generator> CreateGenerator(const std::string& device_tag) {
  std::string device_name = "";
  int device_index = -1;
  JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
  return one::MakeGenerator(device_name, device_index);
}

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<one::Generator, std::shared_ptr<one::Generator>>(m, "Generator")
      .def(py::init([](const std::string& device_tag) {
        return CreateGenerator(device_tag).GetPtrOrThrow();
      }))
      .def("manual_seed",
           [](const std::shared_ptr<one::Generator>& generator,
              const py::object& seed) -> Maybe<void> {
             int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
             generator->set_current_seed(seed_val);
             return Maybe<void>::Ok();
           })
      .def("initial_seed", &one::Generator::current_seed)
      .def("seed", &one::Generator::seed)
      .def_property_readonly("device", &one::Generator::device)
      .def("get_state", &one::Generator::GetState)
      .def("set_state", &one::Generator::SetState);

  m.def("manual_seed", [](const py::object& seed) -> Maybe<one::Generator> {
    int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
    return one::ManualSeed(seed_val);
  });
  m.def("manual_seed",
        [](const py::object& seed, const std::string& device, int device_index) -> Maybe<void> {
          int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
          return one::ManualSeed(seed_val, device, device_index);
        });
  m.def("create_generator", &CreateGenerator);
  m.def("default_generator", [](const std::string& device_tag) -> Maybe<one::Generator> {
    std::string device_name = "";
    int device_index = -1;
    JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
    return one::DefaultGenerator(device_name, device_index);
  });
  m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe<void> {
    int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
    return one::ManualSeedAllCudaGenerator(seed_val);
  });
}

}  // namespace oneflow
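
These generator bindings back OneFlow's seeding API; a short sketch of the behavior they expose, assuming the usual public re-exports (oneflow.manual_seed, oneflow.Generator) and that randn accepts a generator keyword, as in the torch-aligned API:

# Sketch of the generator bindings above via the assumed public re-exports.
import oneflow as flow

flow.manual_seed(42)                 # global seed: the first "manual_seed" overload
gen = flow.Generator("cpu")          # wraps CreateGenerator("cpu")
gen.manual_seed(7)                   # per-generator seed via set_current_seed
x = flow.randn(2, 3, generator=gen)  # reproducible sampling from this generator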
oneflow/api/python/framework/scope_util.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/scope_util.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("GetCurrentScope", &GetCurrentScope);
  m.def("MakeInitialScope",
        [](const std::string& job_conf_str, Symbol<ParallelDesc> placement,
           bool is_mirrored) -> Maybe<Scope> {
          JobConfigProto job_conf;
          CHECK_OR_RETURN(TxtString2PbMessage(job_conf_str, &job_conf)) << "job conf parse failed";
          return MakeInitialScope(job_conf, placement, is_mirrored);
        });
  m.def("InitGlobalScopeStack", &InitThreadLocalScopeStack);
  m.def("GlobalScopeStackPush", &ThreadLocalScopeStackPush);
  m.def("GlobalScopeStackPop", &ThreadLocalScopeStackPop);
}

}  // namespace oneflow
oneflow/api/python/framework/session_util.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/session_util.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<Session, std::shared_ptr<Session>>(m, "Session")
      .def_property_readonly("id", &Session::id)
      .def("push_mirrored_strategy_enabled", &Session::PushMirroredStrategyEnabled)
      .def("pop_mirrored_strategy_enabled", &Session::PopMirroredStrategyEnabled)
      .def("is_mirrored_strategy_enabled", &Session::IsMirroredStrategyEnabled)
      .def("is_consistent_strategy_enabled", &Session::IsConsistentStrategyEnabled)
      .def("is_mirrored_strategy_enabled_stack_size", [](const Session* sess) {
        return sess->is_mirrored_strategy_enabled_stack()->size();
      });

  m.def("GetDefaultSessionId", &GetDefaultSessionId);
  // [sic] "RegsiterSession" is the registered name and the C++ symbol; both keep the typo.
  m.def("RegsiterSession", &RegsiterSession);
  m.def("GetDefaultSession", &GetDefaultSession);
  m.def("ClearSessionById", &ClearSessionById);
}

}  // namespace oneflow
oneflow/api/python/framework/shut_down_util.cpp (new file, mode 100644)

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>

#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/shut_down_util.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("SetShuttingDown", []() { return SetShuttingDown(); });
}

}  // namespace oneflow
oneflow/api/python/framework/size.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/shape.h"
namespace py = pybind11;

namespace oneflow {

using one::functional::PyObjectPtr;

static PyObject* TensorSize_repr(TensorSize* self) {
  std::stringstream ss;
  int32_t idx = 0;
  int32_t size = PyTuple_Size((PyObject*)self);
  ss << "oneflow.Size([";
  for (int i = 0; i < size; ++i) {
    int64_t dim = PyLong_AsLongLong(PyTuple_GET_ITEM(self, i));
    ss << dim;
    if (++idx != size) { ss << ", "; }
  }
  ss << "])";
  return PyUnicode_FromString(ss.str().c_str());
}

static PyObject* TensorSize_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  PyObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs));
  if (self.get()) {
    for (int i = 0; i < PyTuple_Size(self.get()); ++i) {
      PyObject* item = PyTuple_GET_ITEM(self.get(), i);
      if (!PyLong_Check(item)) {
        return PyErr_Format(PyExc_TypeError,
                            "oneflow.Size() takes an iterable of 'int', but item '%d' is '%s'", i,
                            Py_TYPE(item)->tp_name);
      }
    }
  }
  return self.release();
}

static Py_ssize_t TensorSize_length(TensorSize* self) {
  return PyTuple_Type.tp_as_sequence->sq_length((PyObject*)self);
}

static PyObject* TensorSize_concat(TensorSize* self, PyObject* other) {
  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_concat((PyObject*)self, other));
  if (!result.get()) { return nullptr; }
  if (PyTuple_Check(result.get())) {
    PyObjectPtr args(PyTuple_Pack(1, result.get()));
    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
  }
  return result.release();
}

static PyObject* TensorSize_repeat(TensorSize* self, Py_ssize_t n) {
  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_repeat((PyObject*)self, n));
  if (!result.get()) { return nullptr; }
  if (PyTuple_Check(result.get())) {
    PyObjectPtr args(PyTuple_Pack(1, result.get()));
    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
  }
  return result.release();
}

static PyObject* TensorSize_item(TensorSize* self, Py_ssize_t i) {
  return PyTuple_Type.tp_as_sequence->sq_item((PyObject*)self, i);
}

static int TensorSize_contains(TensorSize* self, PyObject* el) {
  return PyTuple_Type.tp_as_sequence->sq_contains((PyObject*)self, el);
}

static PySequenceMethods TensorSize_as_sequence = {
    (lenfunc)TensorSize_length,      /* sq_length */
    (binaryfunc)TensorSize_concat,   /* sq_concat */
    (ssizeargfunc)TensorSize_repeat, /* sq_repeat */
    (ssizeargfunc)TensorSize_item,   /* sq_item */
    0,                               /* sq_slice */
    0,                               /* sq_ass_item */
    0,                               /* sq_ass_slice */
    (objobjproc)TensorSize_contains, /* sq_contains */
};

static PyObject* TensorSize_subscript(TensorSize* self, PyObject* item) {
  PyObjectPtr result(PyTuple_Type.tp_as_mapping->mp_subscript((PyObject*)self, item));
  if (!result.get()) { return nullptr; }
  if (PyTuple_Check(result.get())) {
    PyObjectPtr args(PyTuple_Pack(1, result.get()));
    return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
  }
  return result.release();
};

static PyMappingMethods TensorSize_as_mapping = {
    (lenfunc)TensorSize_length,       /* mp_length */
    (binaryfunc)TensorSize_subscript, /* mp_subscript */
    0,                                /* mp_ass_subscript */
};

static PyObject* TensorSize_numel(PyObject* self, PyObject* args) {
  int64_t numel = 1;
  for (int i = 0; i < PyTuple_Size(self); ++i) {
    numel *= PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i));
  }
  return PyLong_FromLongLong(numel);
}

static PyMethodDef TensorSize_methods[] = {
    {"numel", (PyCFunction)TensorSize_numel, METH_NOARGS, NULL}, {NULL}};

PyTypeObject TensorSize_Type = {
    PyVarObject_HEAD_INIT(NULL, 0) "oneflow.Size", /* tp_name */
    sizeof(TensorSize),                            /* tp_basicsize */
    0,                                             /* tp_itemsize */
    NULL,                                          /* tp_dealloc */
    0,                                             /* tp_vectorcall_offset */
    NULL,                                          /* tp_getattr */
    NULL,                                          /* tp_setattr */
    NULL,                                          /* tp_reserved */
    (reprfunc)TensorSize_repr,                     /* tp_repr */
    NULL,                                          /* tp_as_number */
    &TensorSize_as_sequence,                       /* tp_as_sequence */
    &TensorSize_as_mapping,                        /* tp_as_mapping */
    NULL,                                          /* tp_hash */
    NULL,                                          /* tp_call */
    NULL,                                          /* tp_str */
    NULL,                                          /* tp_getattro */
    NULL,                                          /* tp_setattro */
    NULL,                                          /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,      /* tp_flags */
    NULL,                                          /* tp_doc */
    NULL,                                          /* tp_traverse */
    NULL,                                          /* tp_clear */
    NULL,                                          /* tp_richcompare */
    0,                                             /* tp_weaklistoffset */
    NULL,                                          /* tp_iter */
    NULL,                                          /* tp_iternext */
    TensorSize_methods,                            /* tp_methods */
    NULL,                                          /* tp_members */
    NULL,                                          /* tp_getset */
    &PyTuple_Type,                                 /* tp_base */
    NULL,                                          /* tp_dict */
    NULL,                                          /* tp_descr_get */
    NULL,                                          /* tp_descr_set */
    0,                                             /* tp_dictoffset */
    NULL,                                          /* tp_init */
    NULL,                                          /* tp_alloc */
    TensorSize_new,                                /* tp_new */
    NULL,                                          /* tp_free */
};

int TensorSize_Check(PyObject* p) { return p && p->ob_type == &TensorSize_Type; }

PyObject* TensorSize_New(Py_ssize_t len) {
  return TensorSize_Type.tp_alloc(&TensorSize_Type, len);
}

PyObject* TensorSize_NewFromShape(const Shape& size) {
  PyObjectPtr self(TensorSize_New(size.NumAxes()));
  if (self.get()) {
    for (int i = 0; i < size.NumAxes(); ++i) {
      PyTuple_SET_ITEM(self.get(), i, PyLong_FromLongLong(size.At(i)));
    }
  }
  return self.release();
}

Shape TensorSize_AsShape(PyObject* self) {
  if (!TensorSize_Check(self)) {
    PyErr_Format(PyExc_TypeError, "can only convert TensorSize(not \"%s\") to Shape",
                 Py_TYPE(self)->tp_name);
    return Shape();
  }
  int size = TensorSize_length((TensorSize*)self);
  DimVector dim_vec(size);
  for (int i = 0; i < size; ++i) {
    dim_vec[i] = PyLong_AsLongLong(PyTuple_GET_ITEM((TensorSize*)self, i));
  }
  return Shape(std::move(dim_vec));
}

ONEFLOW_API_PYBIND11_MODULE("", m) {
  if (PyType_Ready(&TensorSize_Type) < 0) { return; }
  Py_INCREF(&TensorSize_Type);
  if (PyModule_AddObject(m.ptr(), "Size", (PyObject*)&TensorSize_Type) < 0) { return; }
}

}  // namespace oneflow
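Taken together, these slots make oneflow.Size a tuple subclass that keeps its type through slicing, concatenation, and repetition, because sq_concat, sq_repeat, and mp_subscript re-wrap any tuple result through TensorSize_new. A minimal Python sketch of that behavior (the printed values assume a build where this module is importable as oneflow):

import oneflow as flow

s = flow.Size((2, 3, 4))
print(s)           # oneflow.Size([2, 3, 4]), per TensorSize_repr
print(s.numel())   # 24, the product of all dims
print(s[1:])       # oneflow.Size([3, 4]): mp_subscript re-wraps the sliced tuple
print(s + s)       # oneflow.Size([2, 3, 4, 2, 3, 4]): sq_concat re-wraps as well
# flow.Size((1, "a")) would raise TypeError from TensorSize_new: items must be 'int'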
oneflow/api/python/framework/size.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
#include <type_traits>
#include <Python.h>
#include <pybind11/pybind11.h>
#include "oneflow/core/common/shape.h"
namespace oneflow {

typedef struct {
  PyTupleObject ob_base;
} TensorSize;

extern PyTypeObject TensorSize_Type;

int TensorSize_Check(PyObject* p);
PyObject* TensorSize_New(Py_ssize_t len);
PyObject* TensorSize_NewFromShape(const Shape& size);
Shape TensorSize_AsShape(PyObject* self);

}  // namespace oneflow

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)

class shape : public object {
 public:
  PYBIND11_OBJECT_CVT(shape, object, oneflow::TensorSize_Check, raw_shape)

  explicit shape(size_t size = 0)
      : object(oneflow::TensorSize_New((ssize_t)size), stolen_t{}) {
    if (!m_ptr) pybind11_fail("Could not allocate tensor size object!");
  }
  size_t size() const { return (size_t)PyTuple_Size(m_ptr); }
  bool empty() const { return size() == 0; }
  detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
  detail::item_accessor operator[](handle h) const { return object::operator[](h); }
  detail::tuple_iterator begin() const { return {*this, 0}; }
  detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }

 private:
  static PyObject* raw_shape(PyObject* op) {
    if (oneflow::TensorSize_Check(op)) return handle(op).inc_ref().ptr();
    return PyObject_CallFunctionObjArgs((PyObject*)&oneflow::TensorSize_Type, op, NULL);
  }
};

PYBIND11_NAMESPACE_BEGIN(detail)

template<typename T>
struct shape_type_caster {
 public:
  bool load(handle src, bool convert) {
    value_ = nullptr;
    if (src && src.is_none()) { return true; }
    if (!oneflow::TensorSize_Check(src.ptr())) { return false; }
    value_ = std::make_shared<T>(oneflow::TensorSize_AsShape(src.ptr()));
    return true;
  }

  template<typename U>
  static handle cast(U&& src, return_value_policy /*policy*/, handle /*parent*/) {
    return cast_impl(std::forward<U>(src));
  }

  template<typename U>
  static handle cast(U* src, return_value_policy policy, handle parent) {
    if (!src) { return none().release(); }
    return cast(*src, policy, parent);
  }

  operator T*() { return value_.get(); }
  operator T&() { return *value_; }
  operator T&&() && { return std::move(*value_); }
  operator std::shared_ptr<T>*() { return &value_; }
  operator std::shared_ptr<T>&() { return value_; }
  operator std::shared_ptr<T>&&() && { return std::move(value_); }

  static constexpr auto name = _("shape");
  template<typename U>
  using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;

 private:
  static handle cast_impl(const oneflow::Shape& src) {
    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(src)).release();
  }

  static handle cast_impl(const std::shared_ptr<const oneflow::Shape>& src) {
    return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(*src)).release();
  }

 protected:
  std::shared_ptr<T> value_;
};

template<>
struct type_caster<oneflow::Shape> : public shape_type_caster<oneflow::Shape> {};
template<>
struct type_caster<std::shared_ptr<oneflow::Shape>> : public shape_type_caster<oneflow::Shape> {};
template<>
struct type_caster<std::shared_ptr<const oneflow::Shape>>
    : public shape_type_caster<const oneflow::Shape> {};

PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
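The type_caster specializations above are what let C++ signatures using oneflow::Shape (by value or by shared_ptr) surface as oneflow.Size on the Python side. A small sketch of the visible effect; whether a given functional entry point actually routes through this caster is an assumption here:

import oneflow as flow

x = flow.ones(2, 3)
s = x.shape                      # a C++ Shape surfaced as oneflow.Size
assert isinstance(s, flow.Size)
y = x.reshape(s)                 # and a Size is accepted where a Shape is expected
assert y.shape == s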
oneflow/api/python/framework/tensor.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/framework/tensor.h"
#include <pybind11/pybind11.h>
#include <Python.h>
#include "oneflow/api/python/exception/exception.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/framework/tensortype.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/python_arg.h"
#include "oneflow/api/python/functional/functional_api.yaml.pybind.h"
#include "oneflow/api/python/functional/tensor_api.yaml.pybind.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/common/stride.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/placement_utils.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/tensor_index.h"
namespace py = pybind11;

namespace oneflow {
namespace one {

#define ASSERT(x) (x).GetOrThrow()
#define ASSERT_PTR(x) (x).GetPtrOrThrow()
#define PY_XINCREF(p) (({ Py_XINCREF(p); }), (p))

#if PY_VERSION_HEX < 0x03070000
#define PYGETSET_NAME(name) const_cast<char*>(name)
#else
#define PYGETSET_NAME(name) (name)
#endif

PyTypeObject* PyTensorObject_Type = NULL;
PyTypeObject* PyParameterObject_Type = NULL;

static int PyTensorObject_init(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  auto* temp = functional::_legacy_tensor_ctor(NULL, args, kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  auto* _self = (PyTensorObject*)self;
  _self->data = PyTensor_Unpack(temp);
  _self->data->set_pyobject(self);
  // reset temp data to prevent clearing the pyobject
  // when the temp is deallocated
  ((PyTensorObject*)temp)->data.reset();
  Py_XDECREF(temp);
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}

static void PyTensorObject_dealloc(PyObject* self) {
  auto* _self = (PyTensorObject*)self;
  // clear pyobject
  if (_self->data) {
    _self->data->set_pyobject(NULL);
    _self->data.reset();
  }
  // clear __dict__
  PyObject** dict_ptr = _PyObject_GetDictPtr(self);
  if (dict_ptr) { Py_CLEAR(*dict_ptr); }
  auto* type = Py_TYPE(self);
  type->tp_free(self);
  Py_DECREF(type);
}

static int PyParameterObject_init(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* data = NULL;
  int requires_grad = 1;
  static const char* keywords[3] = {"data", "requires_grad", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|p:__init__", const_cast<char**>(keywords),
                                   &data, &requires_grad)) {
    return -1;
  }
  if (self) {
    auto* _self = (PyTensorObject*)self;
    _self->data = ASSERT_PTR(Parameter::MakeTensor(PyTensor_Unpack(data), requires_grad));
    _self->data->set_pyobject(self);
  }
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}
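PyParameterObject_init parses only data and an optional requires_grad flag (the "p" format code), with requires_grad defaulting to 1. From Python that looks like the following sketch:

import oneflow as flow

w = flow.nn.Parameter(flow.randn(4, 4))                   # requires_grad defaults to True
b = flow.nn.Parameter(flow.zeros(4), requires_grad=False)
assert w.requires_grad and not b.requires_grad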
static Py_ssize_t PyTensorObject_length(PyTensorObject* self) {
  if (self->data->ndim() == 0) { return 0; }
  return self->data->dim(0);
}

static PyObject* PyTensorObject_getitem(PyObject* self, Py_ssize_t item) {
  HANDLE_ERRORS
  const auto& p = PyTensor_Unpack(self);
  return PyTensor_New(
      ASSERT_PTR(functional::TensorGetItem(p, {functional::detail::IndexItem(item)})));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_subscript(PyObject* self, PyObject* item) {
  HANDLE_ERRORS
  const auto& p = PyTensor_Unpack(self);
  functional::PythonArg arg(item);
  return PyTensor_New(ASSERT_PTR(functional::TensorGetItem(p, arg.As<functional::TensorIndex>())));
  END_HANDLE_ERRORS
}

static PySequenceMethods PyTensorObject_as_sequence = {
    (lenfunc)PyTensorObject_length,
    NULL,                                 /*sq_concat*/
    NULL,                                 /*sq_repeat*/
    (ssizeargfunc)PyTensorObject_getitem, /*sq_item*/
};

extern int PyTensorObject_setitem(PyObject*, PyObject*, PyObject*);

static PyMappingMethods PyTensorObject_as_mapping = {
    (lenfunc)PyTensorObject_length,
    (binaryfunc)PyTensorObject_subscript,
    (objobjargproc)PyTensorObject_setitem,
};
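With these sequence and mapping slots in place, len() reports dim(0) (and 0 for scalar tensors), an integer index goes through sq_item, and everything else (slices, tuples, tensors) goes through mp_subscript; assignment dispatches to the externally defined PyTensorObject_setitem. For example:

import oneflow as flow

x = flow.arange(12).reshape(3, 4)
print(len(x))       # 3, i.e. dim(0)
print(x[1])         # integer index: sq_item -> TensorGetItem
print(x[1:, ::2])   # slice/tuple index: mp_subscript -> TensorGetItem
x[0] = 7            # item assignment: PyTensorObject_setitem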
static PyObject* PyTensorObject_storage_offset(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->storage_offset());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_stride(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  const auto& stride = ASSERT_PTR(PyTensor_Unpack(self)->stride());
  PyObject* tup = PyTuple_New(stride->size());
  for (int i = 0; i < stride->size(); ++i) {
    PyTuple_SetItem(tup, i, PyLong_FromUnsignedLong(stride->at(i)));
  }
  return tup;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_is_contiguous(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_contiguous());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_contiguous(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(PyTensor_Unpack(self)->contiguous());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_contiguous_(PyObject* self, PyObject* unused) {
  // NOTE: inplace version of contiguous
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::InplaceToContiguous(PyTensor_Unpack(self))));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_pin_memory(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(PyTensor_Unpack(self)->pin_memory());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_is_pinned(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(CHECK_JUST(PyTensor_Unpack(self)->is_pinned()));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  int requires_grad = 1;
  static const char* keywords[2] = {"requires_grad", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p:requires_grad_",
                                   const_cast<char**>(keywords), &requires_grad)) {
    return NULL;
  }
  ASSERT(PyTensor_Unpack(self)->set_requires_grad(requires_grad));
  Py_XINCREF(self);
  return self;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_retain_grad(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  if (!t->requires_grad()) {
    return PyErr_Format(PyExc_RuntimeError,
                        "can't retain_grad on Tensor that has requires_grad=False");
  }
  ASSERT(t->set_retain_grad(true));
  Py_RETURN_NONE;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_detach(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->detach()));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_clone(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->clone()));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_zero_(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  ASSERT(EagerMirroredTensorZeros(PyTensor_Unpack(self)));
  Py_XINCREF(self);
  return self;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_register_hook(PyObject* self, PyObject* hook) {
  HANDLE_ERRORS
  const auto& _hook = py::cast<AutogradMeta::Hook>(py::reinterpret_borrow<py::object>(hook));
  ASSERT(RegisterTensorHook(PyTensor_Unpack(self), _hook));
  Py_RETURN_NONE;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject__register_post_grad_accumulation_hook(PyObject* self,
                                                                      PyObject* hook) {
  HANDLE_ERRORS
  const auto& _hook = py::cast<AutogradMeta::Hook>(py::reinterpret_borrow<py::object>(hook));
  ASSERT(RegisterTensorPostGradAccumulationHook(PyTensor_Unpack(self), _hook));
  Py_RETURN_NONE;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_global_id(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  uint64_t global_id = static_cast<uint64_t>(ASSERT(PyTensor_Unpack(self)->transport_token()));
  return functional::CastToPyObject(global_id);
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_check_meta_consistency(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  ASSERT(CheckMetaConsistency(PyTensor_Unpack(self)));
  Py_RETURN_NONE;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_to_numpy(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  DataType data_type = t->dtype()->data_type();
  switch (data_type) {
#define SWITCH_EAGER_TENSOR_TO_NUMPY(cpp_type, of_type) \
  case of_type: return ASSERT(EagerMirroredTensorToNumpy<cpp_type>(self));
    OF_PP_FOR_EACH_TUPLE(SWITCH_EAGER_TENSOR_TO_NUMPY, POD_DATA_TYPE_SEQ)
    case DataType::kFloat16: return ASSERT(EagerMirroredTensorToNumpy<float16>(self));
    default: {
      return PyErr_Format(PyExc_RuntimeError, "Invalid datatype");
    }
  }
#undef SWITCH_EAGER_TENSOR_TO_NUMPY
  END_HANDLE_ERRORS
}
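A short sketch of the autograd-flag methods defined above; note that requires_grad_ returns self (after an extra incref) so calls chain, and retain_grad refuses tensors with requires_grad=False:

import oneflow as flow

x = flow.ones(3).requires_grad_()   # returns self, so it chains
z = x * 2
z.retain_grad()                     # fine: z requires grad through x
y = z.detach()                      # detach() cuts the graph; y.requires_grad is False
try:
    y.retain_grad()
except RuntimeError as e:
    print(e)  # can't retain_grad on Tensor that has requires_grad=False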
static PyObject* PyTensorObject_type(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  const auto& tensor = PyTensor_Unpack(self);
  PyObject* tensor_type = NULL;
  int non_blocking = 0;
  static const char* keywords[3] = {"dtype", "non_blocking", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|Op:type", const_cast<char**>(keywords),
                                   &tensor_type, &non_blocking)) {
    return NULL;
  }
  // TODO: support non_blocking=True
  if (non_blocking == 1) {
    return PyErr_Format(PyExc_TypeError, "non_blocking=True is not supported yet");
  }
  if (tensor_type == NULL) {
    tensor_type = PyTensorType_FromDTypeAndDeviceType(tensor->dtype(),
                                                      ASSERT(tensor->device())->enum_type());
    return PyUnicode_FromString(((PyTensorType*)tensor_type)->name);
  }
  if (PyUnicode_Check(tensor_type)) {
    tensor_type = PyTensorType_FromString(PyUnicode_AsUTF8(tensor_type));
  }
  if (PyTensorType_Check(tensor_type)) {
    const auto& dtype = PyTensorType_UnpackDType(tensor_type);
    DeviceType device_type = PyTensorType_UnpackDevice(tensor_type);
    if (device_type == ASSERT(tensor->device())->enum_type()) {
      return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, /*copy=*/false)));
    }
    Optional<std::string> device = ASSERT(DeviceTag4DeviceType(device_type));
    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, dtype, /*copy=*/false)));
  } else if (functional::PyDTypeCheck(tensor_type)) {
    return PyTensor_New(
        ASSERT_PTR(functional::To(tensor, functional::PyUnpackDType(tensor_type), /*copy=*/false)));
  }
  return PyErr_Format(PyExc_TypeError, "dtype must be a type, str, or dtype object");
  END_HANDLE_ERRORS
}
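tensor.type() therefore has three modes: no argument returns the tensor-type name as a string, a tensor type (or its string name) converts dtype and, if needed, device, and a plain flow.dtype converts dtype only. A sketch; the exact type-name strings (e.g. 'oneflow.FloatTensor') are an assumption about PyTensorType naming:

import oneflow as flow

x = flow.ones(2, 2)
print(x.type())                          # e.g. 'oneflow.FloatTensor' (name string, no conversion)
y = x.type(flow.int64)                   # dtype object: convert dtype on the same device
z = x.type('oneflow.cuda.FloatTensor')   # string: resolved via PyTensorType_FromString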
#define DEFINE_TENSOR_METHOD(T, type_proto)                                               \
  static PyObject* PyTensorObject__copy_to_numpy_##T(PyObject* self, PyObject* array) {   \
    HANDLE_ERRORS                                                                         \
    ASSERT(CopyBetweenMirroredTensorAndNumpy<T>(PyTensor_Unpack(self), array,             \
                                                BlobNumpyCopyUtil<T>::To, "const",        \
                                                /*block_host_until_done=*/true));         \
    Py_RETURN_NONE;                                                                       \
    END_HANDLE_ERRORS                                                                     \
  }                                                                                       \
  static PyObject* PyTensorObject__copy_from_numpy_##T(PyObject* self, PyObject* array) { \
    HANDLE_ERRORS                                                                         \
    auto* copied = PyArray_NewCopy((PyArrayObject*)array, NPY_CORDER);                    \
    ASSERT(CopyBetweenMirroredTensorAndNumpy<T>(PyTensor_Unpack(self), copied,            \
                                                BlobNumpyCopyUtil<T>::From, "mut",        \
                                                /*block_host_until_done=*/false));        \
    Py_DECREF(copied);                                                                    \
    Py_RETURN_NONE;                                                                      \
    END_HANDLE_ERRORS                                                                     \
  }
OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
#undef DEFINE_TENSOR_METHOD

static PyObject* PyTensorObject__get_copy_mirrored_tensor_to_numpy_func_name(PyObject* self,
                                                                             PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(
      GetCopyMirroredTensorToNumpyFuncName(PyTensor_Unpack(self)->dtype()->data_type()));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject__get_copy_mirrored_tensor_from_numpy_func_name(PyObject* self,
                                                                               PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(
      GetCopyMirroredTensorFromNumpyFuncName(PyTensor_Unpack(self)->dtype()->data_type()));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject__register_storage_delete_hook(PyObject* self, PyObject* hook) {
  HANDLE_ERRORS
  auto _hook = py::cast<std::function<void()>>(py::reinterpret_borrow<py::object>(hook));
  ASSERT(PyTensor_Unpack(self)->RegisterStorageDeleteHook(_hook));
  Py_RETURN_NONE;
  END_HANDLE_ERRORS
}

static std::vector<PyMethodDef> concat_method_def(PyMethodDef methods[],
                                                  PyMethodDef extra_methods[]) {
  int len1 = 0;
  int len2 = 0;
  PyMethodDef* p1 = methods;
  PyMethodDef* p2 = extra_methods;
  while ((p1++)->ml_name != NULL) { len1++; }
  while ((p2++)->ml_name != NULL) { len2++; }
  std::vector<PyMethodDef> total_methods(len1 + len2 + 1);
  for (int i = 0; i < len1; i++) total_methods[i] = methods[i];
  for (int i = 0; i < len2; i++) total_methods[i + len1] = extra_methods[i];
  total_methods[len1 + len2] = {NULL};
  return total_methods;
}

static PyMethodDef PyTensorObject_methods[] = {
    {"storage_offset", PyTensorObject_storage_offset, METH_NOARGS, NULL},
    {"stride", PyTensorObject_stride, METH_NOARGS, NULL},
    {"is_contiguous", PyTensorObject_is_contiguous, METH_NOARGS, NULL},
    {"contiguous", PyTensorObject_contiguous, METH_NOARGS, NULL},
    {"contiguous_", PyTensorObject_contiguous_, METH_NOARGS, NULL},
    {"pin_memory", PyTensorObject_pin_memory, METH_NOARGS, NULL},
    {"is_pinned", PyTensorObject_is_pinned, METH_NOARGS, NULL},
    {"requires_grad_", (PyCFunction)PyTensorObject_requires_grad_, METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"retain_grad", PyTensorObject_retain_grad, METH_NOARGS, NULL},
    {"detach", PyTensorObject_detach, METH_NOARGS, NULL},
    {"clone", PyTensorObject_clone, METH_NOARGS, NULL},
    {"zero_", PyTensorObject_zero_, METH_NOARGS, NULL},
    {"register_hook", PyTensorObject_register_hook, METH_O, NULL},
    {"_register_post_grad_accumulation_hook", PyTensorObject__register_post_grad_accumulation_hook,
     METH_O, NULL},
    {"global_id", PyTensorObject_global_id, METH_NOARGS, NULL},
    {"check_meta_consistency", PyTensorObject_check_meta_consistency, METH_NOARGS, NULL},
    {"to_numpy", PyTensorObject_to_numpy, METH_NOARGS, NULL},
    {"type", (PyCFunction)PyTensorObject_type, METH_VARARGS | METH_KEYWORDS, NULL},
#define DEFINE_TENSOR_METHOD(T, type_proto)                                \
  {"_copy_to_numpy_" #T, PyTensorObject__copy_to_numpy_##T, METH_O, NULL}, \
      {"_copy_from_numpy_" #T, PyTensorObject__copy_from_numpy_##T, METH_O, NULL},
    OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
#undef DEFINE_TENSOR_METHOD
    {"_get_copy_mirrored_tensor_to_numpy_func_name",
     PyTensorObject__get_copy_mirrored_tensor_to_numpy_func_name, METH_NOARGS, NULL},
    {"_get_copy_mirrored_tensor_from_numpy_func_name",
     PyTensorObject__get_copy_mirrored_tensor_from_numpy_func_name, METH_NOARGS, NULL},
    {"_register_storage_delete_hook", PyTensorObject__register_storage_delete_hook, METH_O, NULL},
    {NULL}};
static PyObject* PyTensorObject_ndim(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->ndim());
}

static PyObject* PyTensorObject_shape(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->shape());
}

static PyObject* PyTensorObject_dtype(PyObject* self, void* unused) {
  HANDLE_ERRORS
  const Symbol<DType>* dtype = &ASSERT(DType::Get(PyTensor_Unpack(self)->dtype()->data_type()));
  return functional::CastToPyObject(dtype);
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_is_cuda(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_cuda());
}

static PyObject* PyTensorObject_grad(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->acc_grad()));
  END_HANDLE_ERRORS
}

static int PyTensorObject_set_grad(PyObject* self, PyObject* grad, void* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  if (self == grad) {
    PyErr_Format(PyExc_RuntimeError, "can't assign Tensor as its own grad");
    return -1;  // propagate the error instead of falling through and reporting success
  }
  if (grad && grad != Py_None) {
    ASSERT(t->set_acc_grad(ASSERT_PTR(PyTensor_Unpack(grad)->detach())));
  } else {
    ASSERT(t->set_acc_grad(NULL));
  }
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}

static PyObject* PyTensorObject__is_grad_acc_inplace(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->autograd_meta()->is_grad_acc_inplace());
}

static int PyTensorObject_set__is_grad_acc_inplace(PyObject* self, PyObject* is_inplace,
                                                   void* unused) {
  // compare against Py_True; the raw PyObject* pointer would always convert to true
  PyTensor_Unpack(self)->mut_autograd_meta()->set_is_grad_acc_inplace(is_inplace == Py_True);
  return 0;
}

static PyObject* PyTensorObject_data(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(PyTensor_Unpack(self)->data()));
  END_HANDLE_ERRORS
}

static int PyTensorObject_set_data(PyObject* self, PyObject* data, void* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  auto hooks = t->autograd_meta()->hooks();
  ASSERT(t->set_data(PyTensor_Unpack(data)));
  // Re-register hooks
  for (const auto& hook : hooks) { ASSERT(RegisterTensorHook(t, hook)); }
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}

static PyObject* PyTensorObject_grad_fn(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->grad_fn_node());
}

static PyObject* PyTensorObject_is_leaf(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_leaf());
}

static PyObject* PyTensorObject_requires_grad(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->requires_grad());
}

static int PyTensorObject_set_requires_grad(PyObject* self, PyObject* requires_grad,
                                            void* unused) {
  HANDLE_ERRORS
  const auto& t = PyTensor_Unpack(self);
  CHECK_OR_THROW(t->is_leaf()) << Error::RuntimeError()
                               << "You can only change requires_grad flags of leaf tensors.";
  ASSERT(t->set_requires_grad(requires_grad == Py_True));
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}

static PyObject* PyTensorObject_is_lazy(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_lazy());
}

static PyObject* PyTensorObject_is_eager(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_eager());
}

static PyObject* PyTensorObject_is_global(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_consistent());
}

static PyObject* PyTensorObject_is_local(PyObject* self, void* unused) {
  return functional::CastToPyObject(PyTensor_Unpack(self)->is_local());
}

static PyObject* PyTensorObject__tensor_buffer_shapes_and_dtypes(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(MaybeGetTensorBufferShapesAndDTypes(PyTensor_Unpack(self)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_device(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->device());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_placement(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->parallel_desc());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_sbp(PyObject* self, void* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(TensorGetPyTupleOfSbp(*PyTensor_Unpack(self)));
  END_HANDLE_ERRORS
}

// NOLINTNEXTLINE
static PyGetSetDef PyTensorObject_properties[] = {
    {PYGETSET_NAME("ndim"), (getter)PyTensorObject_ndim, NULL, NULL, NULL},
    {PYGETSET_NAME("shape"), (getter)PyTensorObject_shape, NULL, NULL, NULL},
    {PYGETSET_NAME("dtype"), (getter)PyTensorObject_dtype, NULL, NULL, NULL},
    {PYGETSET_NAME("is_cuda"), (getter)PyTensorObject_is_cuda, NULL, NULL, NULL},
    {PYGETSET_NAME("grad"), (getter)PyTensorObject_grad, (setter)PyTensorObject_set_grad, NULL,
     NULL},
    {PYGETSET_NAME("_is_grad_acc_inplace"), (getter)PyTensorObject__is_grad_acc_inplace,
     (setter)PyTensorObject_set__is_grad_acc_inplace, NULL, NULL},
    {PYGETSET_NAME("data"), (getter)PyTensorObject_data, (setter)PyTensorObject_set_data, NULL,
     NULL},
    {PYGETSET_NAME("grad_fn"), (getter)PyTensorObject_grad_fn, NULL, NULL, NULL},
    {PYGETSET_NAME("is_leaf"), (getter)PyTensorObject_is_leaf, NULL, NULL, NULL},
    {PYGETSET_NAME("requires_grad"), (getter)PyTensorObject_requires_grad,
     (setter)PyTensorObject_set_requires_grad, NULL, NULL},
    {PYGETSET_NAME("is_lazy"), (getter)PyTensorObject_is_lazy, NULL, NULL, NULL},
    {PYGETSET_NAME("is_eager"), (getter)PyTensorObject_is_eager, NULL, NULL, NULL},
    {PYGETSET_NAME("is_global"), (getter)PyTensorObject_is_global, NULL, NULL, NULL},
    {PYGETSET_NAME("is_local"), (getter)PyTensorObject_is_local, NULL, NULL, NULL},
    {PYGETSET_NAME("_tensor_buffer_shapes_and_dtypes"),
     (getter)PyTensorObject__tensor_buffer_shapes_and_dtypes, NULL, NULL, NULL},
    {PYGETSET_NAME("device"), (getter)PyTensorObject_device, NULL, NULL, NULL},
    {PYGETSET_NAME("placement"), (getter)PyTensorObject_placement, NULL, NULL, NULL},
    {PYGETSET_NAME("sbp"), (getter)PyTensorObject_sbp, NULL, NULL, NULL},
    {NULL}};
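The getset table exposes these as plain attributes. Two behaviors worth calling out from the setters: requires_grad may only be flipped on leaf tensors, and assigning to .grad detaches the assigned tensor first. For example:

import oneflow as flow

x = flow.randn(2, 3, requires_grad=True)
print(x.shape, x.ndim, x.dtype, x.device, x.is_leaf)
y = x.sum()                   # y is a non-leaf result
try:
    y.requires_grad = False   # setter raises: only leaf tensors can change the flag
except RuntimeError as e:
    print(e)
x.grad = flow.ones(2, 3)      # stored as a detached accumulated grad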
// create a Tensor instance
static PyObject* TensorMetaCls_call(PyObject* type, PyObject* args, PyObject* kwargs) {
  return PyType_Type.tp_call(type, args, kwargs);
}

static void TensorMetaCls_dealloc(PyObject* type) { PyType_Type.tp_dealloc(type); }

static PyHeapTypeObject* MakeTensorMetaclass() {
  PyObject* name = PyUnicode_FromString("_TensorMeta");
  auto* heap_type = (PyHeapTypeObject*)PyType_Type.tp_alloc(&PyType_Type, 0);
  heap_type->ht_name = name;
  heap_type->ht_qualname = PY_XINCREF(name);
  auto* type = &heap_type->ht_type;
  type->tp_name = "_TensorMeta";
  type->tp_base = PY_XINCREF(&PyType_Type);
  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
  type->tp_call = TensorMetaCls_call;
  type->tp_dealloc = TensorMetaCls_dealloc;
  if (PyType_Ready(type) < 0) { return NULL; }
  PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow._C"));
  return heap_type;
}

extern PyNumberMethods PyTensorObject_as_number;
extern PyObject* PyTensorObject_richcompare(PyObject*, PyObject*, int);
extern PyMethodDef PyTensorObject_extra_methods[];

static PyHeapTypeObject* TensorMetaclass_Type = MakeTensorMetaclass();

static PyTypeObject* MakeTensorType() {
  PyObject* name = PyUnicode_FromString("Tensor");
  auto* metaclass = &TensorMetaclass_Type->ht_type;
  auto* heap_type = (PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0);
  if (!heap_type) { return NULL; }
  heap_type->ht_name = name;
  heap_type->ht_qualname = PY_XINCREF(name);
  auto* type = &heap_type->ht_type;
  type->tp_name = "Tensor";
  type->tp_basicsize = sizeof(PyTensorObject);
  type->tp_init = PyTensorObject_init;
  type->tp_dealloc = PyTensorObject_dealloc;
  type->tp_getset = PyTensorObject_properties;
  static std::vector<PyMethodDef> total_methods =
      concat_method_def(PyTensorObject_methods, PyTensorObject_extra_methods);
  type->tp_methods = total_methods.data();
  type->tp_as_number = &PyTensorObject_as_number;
  type->tp_as_sequence = &PyTensorObject_as_sequence;
  type->tp_as_mapping = &PyTensorObject_as_mapping;
  type->tp_richcompare = PyTensorObject_richcompare;
  type->tp_hash = (hashfunc)_Py_HashPointer;
  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
  if (PyType_Ready(type) < 0) { return NULL; }
  PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow"));
  return type;
}

static PyTypeObject* MakeParameterType() {
  PyObject* name = PyUnicode_FromString("Parameter");
  auto* metaclass = &TensorMetaclass_Type->ht_type;
  auto* heap_type = (PyHeapTypeObject*)metaclass->tp_alloc(metaclass, 0);
  if (!heap_type) { return NULL; }
  heap_type->ht_name = name;
  heap_type->ht_qualname = PY_XINCREF(name);
  auto* type = &heap_type->ht_type;
  type->tp_name = "Parameter";
  type->tp_basicsize = sizeof(PyTensorObject);
  type->tp_init = PyParameterObject_init;
  type->tp_base = PY_XINCREF(PyTensorObject_Type);
  type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
  if (PyType_Ready(type) < 0) { return NULL; }
  PyObject_SetAttrString((PyObject*)type, "__module__", PyUnicode_FromString("oneflow.nn"));
  return type;
}
PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data) {
  if (!data) { Py_RETURN_NONE; }
  if (data->pyobject()) { return PY_XINCREF((PyObject*)(data->pyobject())); }
  auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyTensorObject_Type, 0);
  if (self) {
    self->data = data;
    self->data->set_pyobject(self);
  }
  return (PyObject*)self;
}

PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data) {
  if (!data) { Py_RETURN_NONE; }
  if (data->pyobject()) { return PY_XINCREF((PyObject*)(data->pyobject())); }
  auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyParameterObject_Type, 0);
  if (self) {
    self->data = data;
    self->data->set_pyobject(self);
  }
  return (PyObject*)self;
}

PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad) {
  if (!data) { Py_RETURN_NONE; }
  auto* self = (PyTensorObject*)PyTensorObject_Type->tp_alloc(PyParameterObject_Type, 0);
  if (self) {
    self->data = ASSERT_PTR(Parameter::MakeTensor(data, requires_grad));
    self->data->set_pyobject(self);
  }
  return (PyObject*)self;
}

}  // namespace one
}  // namespace oneflow

#undef ASSERT
#undef ASSERT_PTR

using namespace oneflow::one;

ONEFLOW_API_PYBIND11_MODULE("", m) {
  PyTensorObject_Type = MakeTensorType();
  PyParameterObject_Type = MakeParameterType();
  if (PyTensorObject_Type
      && PyModule_AddObject(m.ptr(), "Tensor", (PyObject*)PyTensorObject_Type) < 0) {
    return;
  }
  auto nn = m.def_submodule("nn");
  if (PyParameterObject_Type
      && PyModule_AddObject(nn.ptr(), "Parameter", (PyObject*)PyParameterObject_Type) < 0) {
    return;
  }
}
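So the C type built by MakeTensorType is published as oneflow.Tensor, and the Parameter type, whose tp_base is the Tensor type, lands in the nn submodule. A sketch of what that registration implies:

import oneflow as flow

assert flow.Tensor.__module__ == 'oneflow'
assert flow.nn.Parameter.__module__ == 'oneflow.nn'
assert issubclass(flow.nn.Parameter, flow.Tensor)   # tp_base wiring above
t = flow.Tensor(2, 3)   # routed through PyTensorObject_init / _legacy_tensor_ctor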
oneflow/api/python/framework/tensor.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_
#include <Python.h>
#include "oneflow/core/framework/tensor.h"
namespace oneflow {
namespace one {

typedef struct {
  PyObject_HEAD;
  std::shared_ptr<Tensor> data;
} PyTensorObject;

extern PyTypeObject* PyTensorObject_Type;
extern PyTypeObject* PyParameterObject_Type;

inline bool PyTensor_Check(PyObject* op) { return PyObject_TypeCheck(op, PyTensorObject_Type); }

inline bool PyTensor_CheckExact(PyObject* op) {
  return op->ob_type == PyTensorObject_Type || op->ob_type == PyParameterObject_Type;
}

inline std::shared_ptr<Tensor>& PyTensor_Unpack(PyObject* op) {
  assert(PyTensor_Check(op));
  return ((PyTensorObject*)op)->data;
}

PyObject* PyTensor_New(const std::shared_ptr<Tensor>& data);
PyObject* PyParameter_New(const std::shared_ptr<Parameter>& data);
PyObject* PyParameter_New(const std::shared_ptr<Tensor>& data, bool requires_grad);

}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_TENSOR_H_
oneflow/api/python/framework/tensor_functions.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <Python.h>
#include "oneflow/api/python/exception/exception.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/functional_api.yaml.pybind.h"
#include "oneflow/core/common/shape_vec.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/wrap_dim_utils.h"
namespace oneflow {
namespace one {

#define ASSERT(x) (x).GetOrThrow()
#define ASSERT_PTR(x) (x).GetPtrOrThrow()

using functional::PyObjectPtr;

static PyObject* concat_self(PyObject* self, PyObject* args) {
  PyObjectPtr self_tuple(PyTuple_Pack(1, self));
  PyObject* tuple = PySequence_Concat(self_tuple.get(), args);
  CHECK_OR_THROW(tuple != NULL);
  return tuple;
}
#define NB_UNARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* self) { \
HANDLE_ERRORS \
PyObjectPtr tuple(PyTuple_Pack(1, self)); \
auto* result = bind_func(NULL, tuple.get(), NULL); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \
END_HANDLE_ERRORS \
}
#define NB_BINARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* a, PyObject* b) { \
HANDLE_ERRORS \
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \
auto* result = bind_func(NULL, tuple.get(), NULL); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \
END_HANDLE_ERRORS \
}
NB_UNARY_FUNC(PyTensorObject_nb_absolute, functional::abs);
NB_UNARY_FUNC(PyTensorObject_nb_negative, functional::negative);
// TODO: not implemented yet
// NB_UNARY_FUNC(PyTensorObject_positive, functional::positive);
NB_BINARY_FUNC(PyTensorObject_nb_add, functional::add);
NB_BINARY_FUNC(PyTensorObject_nb_sub, functional::sub);
NB_BINARY_FUNC(PyTensorObject_nb_mul, functional::mul);
NB_BINARY_FUNC(PyTensorObject_nb_fmod, functional::fmod);
NB_BINARY_FUNC(PyTensorObject_nb_div, functional::div);
NB_BINARY_FUNC(PyTensorObject_nb_and, functional::logical_and);
NB_BINARY_FUNC(PyTensorObject_nb_xor, functional::logical_xor);
NB_BINARY_FUNC(PyTensorObject_nb_or, functional::logical_or);
NB_BINARY_FUNC(PyTensorObject_nb_floor_div, functional::floor_divide);
NB_BINARY_FUNC(PyTensorObject_nb_true_div, functional::div);
NB_BINARY_FUNC(PyTensorObject_nb_matrix_multiply, functional::matmul);
static PyObject* PyTensorObject_nb_pow(PyObject* a, PyObject* b, PyObject* unused) {
  HANDLE_ERRORS
  PyObjectPtr tuple(PyTuple_Pack(2, a, b));
  PyObject* result = functional::pow(NULL, tuple.get(), NULL);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_nb_invert(PyObject* self) {
  HANDLE_ERRORS
  CHECK_OR_THROW(PyTensor_Unpack(self)->dtype()->data_type() == DataType::kBool)
      << "~ (operator.invert) is only implemented on integer and Boolean-type tensors";
  PyObjectPtr tuple(PyTuple_Pack(1, self));
  PyObject* result = functional::logical_not(NULL, tuple.get(), NULL);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
#define NB_INPLACE_BINARY_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* a, PyObject* b) { \
HANDLE_ERRORS \
PyObjectPtr tuple(PyTuple_Pack(2, a, b)); \
PyObjectPtr dict(PyDict_New()); \
CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1); \
PyObject* result = bind_func(NULL, tuple.get(), dict.get()); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \
END_HANDLE_ERRORS \
}
// inplace operators
NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_add, functional::add);
NB_INPLACE_BINARY_FUNC(PyTensorObject_nb_inplace_sub, functional::sub);
// The interface of inplace mul is not mul(*, inplace=True) but mul_
NB_BINARY_FUNC(PyTensorObject_nb_inplace_mul, functional::mul_);
NB_BINARY_FUNC(PyTensorObject_nb_inplace_true_div, functional::div_);

PyObject* PyTensorObject_nb_inplace_pow(PyObject* a, PyObject* b, PyObject* unused) {
  HANDLE_ERRORS
  PyObjectPtr tuple(PyTuple_Pack(2, a, b));
  PyObjectPtr dict(PyDict_New());
  CHECK_OR_THROW(PyDict_SetItemString(dict.get(), "inplace", Py_True) > -1);
  // pass the kwargs dict so pow actually runs in place (the original passed NULL here,
  // leaving the inplace flag unused)
  auto* result = functional::pow(NULL, tuple.get(), dict.get());
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
PyNumberMethods PyTensorObject_as_number = {
    PyTensorObject_nb_add,               // nb_add
    PyTensorObject_nb_sub,               // nb_subtract
    PyTensorObject_nb_mul,               // nb_multiply
    PyTensorObject_nb_fmod,              // nb_remainder
    NULL,                                // nb_divmod
    PyTensorObject_nb_pow,               // nb_power
    PyTensorObject_nb_negative,          // nb_negative
    NULL,                                // nb_positive
    PyTensorObject_nb_absolute,          // nb_absolute
    NULL,                                // nb_bool
    PyTensorObject_nb_invert,            // nb_invert
    NULL,                                // nb_lshift
    NULL,                                // nb_rshift
    PyTensorObject_nb_and,               // nb_and
    PyTensorObject_nb_xor,               // nb_xor
    PyTensorObject_nb_or,                // nb_or
    NULL,                                // nb_int
    NULL,                                // nb_reserved
    NULL,                                // nb_float
    PyTensorObject_nb_inplace_add,       // nb_inplace_add
    PyTensorObject_nb_inplace_sub,       // nb_inplace_sub
    PyTensorObject_nb_inplace_mul,       // nb_inplace_mul
    NULL,                                // nb_inplace_remainder
    PyTensorObject_nb_inplace_pow,       // nb_inplace_pow
    NULL,                                // nb_inplace_lshift
    NULL,                                // nb_inplace_rshift
    NULL,                                // nb_inplace_and
    NULL,                                // nb_inplace_xor
    NULL,                                // nb_inplace_or
    PyTensorObject_nb_floor_div,         // nb_floor_div
    PyTensorObject_nb_true_div,          // nb_true_div
    NULL,                                // nb_inplace_floor_div
    PyTensorObject_nb_inplace_true_div,  // nb_inplace_true_div
    NULL,                                // nb_index
    PyTensorObject_nb_matrix_multiply,   // nb_matrix_multiply
    NULL,                                // nb_inplace_matrix_multiply
};
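This table wires Python's operator protocol onto the functional API: the binary slots map one-to-one onto functional::add and friends, @ maps to matmul, the augmented-assignment slots pass inplace=True (or call the trailing-underscore variants), and ~ is only defined for bool tensors. For example:

import oneflow as flow

a = flow.ones(2, 2)
b = flow.ones(2, 2) * 2
print(a + b, a - b, a * b, a / b)   # nb_add / nb_sub / nb_mul / nb_true_div
print(a @ b)                        # nb_matrix_multiply -> functional::matmul
a += b                              # nb_inplace_add passes inplace=True
print(~(a > 1))                     # nb_invert, defined only for bool tensors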
// extra methods
// functions that accept only one Tensor
#define UNARY_METHOD(func_name, bind_func) \
static PyObject* func_name(PyObject* self, PyObject* unused) { \
HANDLE_ERRORS \
return PyTensor_New(ASSERT_PTR(bind_func(PyTensor_Unpack(self)))); \
END_HANDLE_ERRORS \
}
UNARY_METHOD(PyTensorObject_abs, functional::Abs);
UNARY_METHOD(PyTensorObject_exp, functional::Exp);
UNARY_METHOD(PyTensorObject_floor, functional::Floor);
UNARY_METHOD(PyTensorObject_floor_, functional::Floor_);
UNARY_METHOD(PyTensorObject_sign, functional::Sign);
UNARY_METHOD(PyTensorObject_gelu, functional::Gelu);
UNARY_METHOD(PyTensorObject_mish, functional::Mish);
UNARY_METHOD(PyTensorObject_negative, functional::Negative);
UNARY_METHOD(PyTensorObject_sigmoid, functional::Sigmoid);
UNARY_METHOD(PyTensorObject_silu, functional::Silu);
UNARY_METHOD(PyTensorObject_selu, functional::Selu);
UNARY_METHOD(PyTensorObject_softsign, functional::SoftSign);
UNARY_METHOD(PyTensorObject_log1p, functional::Log1p);
UNARY_METHOD(PyTensorObject_log2, functional::Log2);
UNARY_METHOD(PyTensorObject_reciprocal, functional::Reciprocal);
UNARY_METHOD(PyTensorObject_ceil, functional::Ceil);
UNARY_METHOD(PyTensorObject_erf, functional::Erf);
UNARY_METHOD(PyTensorObject_erfc, functional::Erfc);
UNARY_METHOD(PyTensorObject_erfinv, functional::Erfinv);
UNARY_METHOD(PyTensorObject_erfinv_, functional::ErfinvInplace);
UNARY_METHOD(PyTensorObject_expm1, functional::Expm1);
UNARY_METHOD(PyTensorObject_log, functional::Log);
UNARY_METHOD(PyTensorObject_rsqrt, functional::Rsqrt);
UNARY_METHOD(PyTensorObject_sqrt, functional::Sqrt);
UNARY_METHOD(PyTensorObject_square, functional::Square);
UNARY_METHOD(PyTensorObject_round, functional::Round);
UNARY_METHOD(PyTensorObject_t, functional::TransposeAllDimFunction);
UNARY_METHOD(PyTensorObject_isnan, functional::IsNan);
UNARY_METHOD(PyTensorObject_isinf, functional::IsInf);
UNARY_METHOD(PyTensorObject_sin, functional::Sin);
UNARY_METHOD(PyTensorObject_sin_, functional::Sin_);
UNARY_METHOD(PyTensorObject_asin, functional::Asin);
UNARY_METHOD(PyTensorObject_cos, functional::Cos);
UNARY_METHOD(PyTensorObject_acos, functional::Acos);
UNARY_METHOD(PyTensorObject_tan, functional::Tan);
UNARY_METHOD(PyTensorObject_atan, functional::Atan);
UNARY_METHOD(PyTensorObject_sinh, functional::Sinh);
UNARY_METHOD(PyTensorObject_asinh, functional::Asinh);
UNARY_METHOD(PyTensorObject_cosh, functional::Cosh);
UNARY_METHOD(PyTensorObject_acosh, functional::Acosh);
UNARY_METHOD(PyTensorObject_tanh, functional::Tanh);
UNARY_METHOD(PyTensorObject_atanh, functional::Atanh);
UNARY_METHOD(PyTensorObject_logical_not, functional::LogicalNot);
// functions that directly pass arguments without parsing
#define DIRECT_PASS_FUNC(func_name, bind_func) \
static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
HANDLE_ERRORS \
PyObjectPtr concat_args(concat_self(self, args)); \
PyObject* result = bind_func(NULL, concat_args.get(), kwargs); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \
END_HANDLE_ERRORS \
}
DIRECT_PASS_FUNC(PyTensorObject_floor_divide, functional::floor_divide)
DIRECT_PASS_FUNC(PyTensorObject_atan2, functional::atan2)
DIRECT_PASS_FUNC(PyTensorObject_gt, functional::greater)
DIRECT_PASS_FUNC(PyTensorObject_ge, functional::greater_equal)
DIRECT_PASS_FUNC(PyTensorObject_div, functional::div)
DIRECT_PASS_FUNC(PyTensorObject_div_, functional::div_)
DIRECT_PASS_FUNC(PyTensorObject_mul, functional::mul)
DIRECT_PASS_FUNC(PyTensorObject_mul_, functional::mul_)
DIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod)
DIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and)
DIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or)
DIRECT_PASS_FUNC(PyTensorObject_logical_xor, functional::logical_xor)
DIRECT_PASS_FUNC(PyTensorObject_ne, functional::not_equal)
DIRECT_PASS_FUNC(PyTensorObject_lt, functional::less)
DIRECT_PASS_FUNC(PyTensorObject_le, functional::less_equal)
DIRECT_PASS_FUNC(PyTensorObject_bmm, functional::batch_matmul)
DIRECT_PASS_FUNC(PyTensorObject_argmax, functional::argmax)
DIRECT_PASS_FUNC(PyTensorObject_argmin, functional::argmin)
DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin)
DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax)
DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul)
DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_)
DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip)
DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_)
DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp)
DIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_)
DIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten)
DIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k)
DIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select)
DIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum)
DIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum)
DIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril)
DIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu)
DIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax)
DIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax)
DIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll)
DIRECT_PASS_FUNC(PyTensorObject_unbind, functional::unbind)
DIRECT_PASS_FUNC(PyTensorObject_squeeze, functional::squeeze)
DIRECT_PASS_FUNC(PyTensorObject_swapaxes, functional::swapaxes)
DIRECT_PASS_FUNC(PyTensorObject_swapdims, functional::swapdims)
DIRECT_PASS_FUNC(PyTensorObject_unfold, functional::unfold_tensor)
DIRECT_PASS_FUNC(PyTensorObject_unsqueeze, functional::unsqueeze)
DIRECT_PASS_FUNC(PyTensorObject_max, functional::max)
DIRECT_PASS_FUNC(PyTensorObject_min, functional::min)
DIRECT_PASS_FUNC(PyTensorObject_median, functional::median)
DIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow)
DIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk)
DIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow)
DIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill)
DIRECT_PASS_FUNC(PyTensorObject_dot, functional::dot)
// functions that parse their arguments at the Python C API layer
static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), DType::UInt8(), false)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_dim(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->ndim());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_nelement(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->nelement());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_element_size(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return functional::CastToPyObject(PyTensor_Unpack(self)->dtype()->bytes());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_get_device(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  DeviceType device_type = ASSERT(PyTensor_Unpack(self)->device())->enum_type();
  CHECK_OR_THROW(device_type == DeviceType::kCUDA)
      << "get_device is only available for GPU tensor.";
  return functional::CastToPyObject(ASSERT(PyTensor_Unpack(self)->device())->device_id());
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_size(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* idx_obj = Py_None;
  static const char* keywords[2] = {"idx", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:size", const_cast<char**>(keywords),
                                   &idx_obj)) {
    return NULL;
  }
  auto shape = PyTensor_Unpack(self)->shape();
  if (idx_obj == NULL || idx_obj == Py_None) return TensorSize_NewFromShape(*shape);
  int64_t idx = PyLong_AsLongLong(idx_obj);
  int64_t ndim = shape->NumAxes();
  idx = CHECK_JUST(maybe_wrap_dim(idx, ndim));
  idx = idx < 0 ? idx + ndim : idx;
  return PyLong_FromLongLong(shape->At(idx));
  END_HANDLE_ERRORS
}
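tensor.size() mirrors the shape property when called with no argument and returns a single dimension otherwise, with negative indices wrapped by maybe_wrap_dim:

import oneflow as flow

x = flow.ones(2, 3, 4)
print(x.size())     # oneflow.Size([2, 3, 4]) via TensorSize_NewFromShape
print(x.size(1))    # 3
print(x.size(-1))   # 4: negative indices are wrapped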
static PyObject* PyTensorObject_cast(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dtype = NULL;
  PyObject* pin_memory = Py_False;
  static const char* keywords[3] = {"dtype", "pin_memory", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!:cast", const_cast<char**>(keywords), &dtype,
                                   &PyBool_Type, &pin_memory)) {
    return NULL;
  }
  CHECK_OR_THROW(functional::PyDTypeCheck(dtype))
      << Error::TypeError() << "cast(): argument 'dtype' must be data type, but found "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dtype)));
  const auto& result = functional::Cast(PyTensor_Unpack(self), functional::PyUnpackDType(dtype),
                                        pin_memory == Py_True);
  return PyTensor_New(ASSERT_PTR(result));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_diag(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  int32_t diagonal = 0;
  static const char* keywords[2] = {"diagonal", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:diag", const_cast<char**>(keywords),
                                   &diagonal)) {
    return NULL;
  }
  return PyTensor_New(ASSERT_PTR(functional::Diag(PyTensor_Unpack(self), diagonal)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_diagonal(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  int32_t offset = 0;
  int32_t dim1 = 0;
  int32_t dim2 = 1;
  static const char* keywords[4] = {"offset", "dim1", "dim2", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iii:diagonal", const_cast<char**>(keywords),
                                   &offset, &dim1, &dim2)) {
    return NULL;
  }
  return PyTensor_New(ASSERT_PTR(functional::Diagonal(PyTensor_Unpack(self), offset, dim1, dim2)));
  END_HANDLE_ERRORS
}
static
PyObject
*
PyTensorObject_matmul
(
PyObject
*
self
,
PyObject
*
args
,
PyObject
*
kwargs
)
{
HANDLE_ERRORS
PyObject
*
other
=
NULL
;
static
const
char
*
keywords
[
2
]
=
{
"other"
,
NULL
};
if
(
!
PyArg_ParseTupleAndKeywords
(
args
,
kwargs
,
"O:matmul"
,
const_cast
<
char
**>
(
keywords
),
&
other
))
{
return
NULL
;
}
PyObjectPtr
concat_args
(
PyTuple_Pack
(
2
,
self
,
other
));
PyObject
*
result
=
functional
::
matmul
(
NULL
,
concat_args
.
get
(),
NULL
);
if
(
PyErr_Occurred
())
{
throw
py
::
error_already_set
();
}
return
result
;
END_HANDLE_ERRORS
}
static
PyObject
*
PyTensorObject_reshape
(
PyObject
*
self
,
PyObject
*
args
,
PyObject
*
kwargs
)
{
HANDLE_ERRORS
PyObject
*
shape
=
args
;
if
(
PyTuple_Size
(
args
)
==
1
)
{
PyObject
*
item
=
PyTuple_GetItem
(
args
,
0
);
if
(
!
PyLong_Check
(
item
))
{
shape
=
item
;
}
}
PyObjectPtr
_args
=
PyObjectPtr
(
PyTuple_Pack
(
2
,
self
,
shape
));
PyObject
*
result
=
functional
::
reshape
(
NULL
,
_args
.
get
(),
kwargs
);
if
(
PyErr_Occurred
())
{
throw
py
::
error_already_set
();
}
return
result
;
END_HANDLE_ERRORS
}
static
PyObject
*
PyTensorObject_reshape_as
(
PyObject
*
self
,
PyObject
*
args
,
PyObject
*
kwargs
)
{
HANDLE_ERRORS
auto
tensor
=
PyTensor_Unpack
(
self
);
PyObject
*
other
=
NULL
;
static
const
char
*
keywords
[
2
]
=
{
"other"
,
NULL
};
if
(
!
PyArg_ParseTupleAndKeywords
(
args
,
kwargs
,
"O|:reshape_as"
,
const_cast
<
char
**>
(
keywords
),
&
other
))
{
return
NULL
;
}
return
PyTensor_New
(
ASSERT_PTR
(
functional
::
Reshape
(
tensor
,
*
PyTensor_Unpack
(
other
)
->
shape
())));
END_HANDLE_ERRORS
}
static
PyObject
*
PyTensorObject_cpu
(
PyObject
*
self
,
PyObject
*
unused
)
{
HANDLE_ERRORS
Optional
<
std
::
string
>
device
=
"cpu"
;
return
PyTensor_New
(
ASSERT_PTR
(
functional
::
To
(
PyTensor_Unpack
(
self
),
device
,
NullOpt
,
false
)));
END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_cuda(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* device_obj = Py_None;
  static const char* keywords[2] = {"device", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:cuda", const_cast<char**>(keywords),
                                   &device_obj)) {
    return NULL;
  }
  auto tensor = PyTensor_Unpack(self);
  if (functional::PyDeviceCheck(device_obj)) {
    Optional<Symbol<Device>> device = functional::PyUnpackDevice(device_obj);
    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, NullOpt, false)));
  }
  Optional<std::string> device_str;
  if (device_obj == Py_None) {
    device_str = "cuda";
  } else if (PyLong_Check(device_obj)) {
    device_str = "cuda:" + std::to_string(PyLong_AsLongLong(device_obj));
  }
  return PyTensor_New(ASSERT_PTR(functional::To(tensor, device_str, tensor->dtype(), false)));
  END_HANDLE_ERRORS
}
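
// `cuda` accepts three forms for `device` (illustrative Python):
//   t.cuda()      -> string "cuda", i.e. the current CUDA device
//   t.cuda(1)     -> string "cuda:1"
//   t.cuda(dev)   -> a oneflow device object, unpacked directly as a Device symbol
// Any other object leaves `device_str` as an empty Optional, which is passed through to
// functional::To unchanged and handled there.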
static PyObject* PyTensorObject_var(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dim_obj = Py_None;
  PyObject* unbiased_obj = Py_True;
  PyObject* keepdim_obj = Py_False;
  static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:var", const_cast<char**>(keywords),
                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,
                                   &keepdim_obj)) {
    return NULL;
  }
  bool unbiased = unbiased_obj == Py_True;
  bool keepdim = keepdim_obj == Py_True;
  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)
                 || functional::PyLongSequenceCheck(dim_obj))
      << Error::TypeError() << "var(): argument 'dim' must be int32 list, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));
  auto tensor = PyTensor_Unpack(self);
  if (dim_obj == Py_None) {
    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, NullOpt, unbiased, keepdim)));
  }
  std::vector<int32_t> dim;
  if (PyLong_Check(dim_obj)) {
    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));
    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));
  }
  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);
  return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));
  END_HANDLE_ERRORS
}
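
// var and std share the same `dim` normalization before dispatch: Py_None reduces over the
// whole tensor (NullOpt), a single Python int becomes a one-element std::vector<int32_t>,
// and an int sequence is unpacked wholesale. Illustrative: t.var(dim=(0, 2), keepdim=True).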
static PyObject* PyTensorObject_std(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dim_obj = Py_None;
  PyObject* unbiased_obj = Py_True;
  PyObject* keepdim_obj = Py_False;
  static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:std", const_cast<char**>(keywords),
                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,
                                   &keepdim_obj)) {
    return NULL;
  }
  bool unbiased = unbiased_obj == Py_True;
  bool keepdim = keepdim_obj == Py_True;
  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)
                 || functional::PyLongSequenceCheck(dim_obj))
      << Error::TypeError() << "std(): argument 'dim' must be int32 list, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));
  auto tensor = PyTensor_Unpack(self);
  if (dim_obj == Py_None) {
    return PyTensor_New(
        ASSERT_PTR(functional::StandardDeviation(tensor, NullOpt, unbiased, keepdim)));
  }
  std::vector<int32_t> dim;
  if (PyLong_Check(dim_obj)) {
    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));
    return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));
  }
  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);
  return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_softplus(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  double beta = 1.0;
  double threshold = 20.0;
  static const char* keywords[3] = {"beta", "threshold", NULL};
  // Both arguments carry defaults, so they must be optional in the format string
  // ("|dd", not "dd"); otherwise `t.softplus()` would raise a TypeError.
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|dd:softplus", const_cast<char**>(keywords),
                                   &beta, &threshold)) {
    return NULL;
  }
  return PyTensor_New(ASSERT_PTR(functional::Softplus(PyTensor_Unpack(self), beta, threshold)));
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_relu(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), false)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_relu_(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), true)));
  END_HANDLE_ERRORS
}
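
// The trailing boolean on functional::Relu selects in-place evaluation: relu() passes
// false and allocates a result, relu_() passes true and mutates self, mirroring the
// PyTorch trailing-underscore convention.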
#define REDUCE_FUNC(func_name, bind_func, whole_func) \
static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
HANDLE_ERRORS \
if ((args == NULL || PyTuple_Size(args) == 0) \
&& (kwargs == NULL || PyDict_Size(kwargs) == 0)) { \
return PyTensor_New(ASSERT_PTR(whole_func(PyTensor_Unpack(self)))); \
} \
PyObjectPtr concat_args(concat_self(self, args)); \
PyObject* result = bind_func(NULL, concat_args.get(), kwargs); \
if (PyErr_Occurred()) { throw py::error_already_set(); } \
return result; \
END_HANDLE_ERRORS \
}
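
// For reference, REDUCE_FUNC(PyTensorObject_sum, functional::reduce_sum,
// functional::ReduceSumWhole) expands to a method that calls ReduceSumWhole when invoked
// with no arguments (`t.sum()`), and otherwise prepends `self` to `args` and forwards to
// the generated reduce_sum binding (`t.sum(dim=1, keepdim=True)`).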
REDUCE_FUNC(PyTensorObject_any, functional::reduce_any, functional::ReduceAnyWhole)
REDUCE_FUNC(PyTensorObject_all, functional::reduce_all, functional::ReduceAllWhole)
REDUCE_FUNC(PyTensorObject_sum, functional::reduce_sum, functional::ReduceSumWhole)
REDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMeanWhole)
#define DATATYPE_FUNC(func_name, dtype) \
static PyObject* func_name(PyObject* self, PyObject* unused) { \
HANDLE_ERRORS \
auto tensor = PyTensor_Unpack(self); \
return PyTensor_New(ASSERT_PTR(functional::To(tensor, dtype, false))); \
END_HANDLE_ERRORS \
}
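
// Example expansion: DATATYPE_FUNC(PyTensorObject_int, DType::Int32()) defines `t.int()`
// as functional::To(tensor, DType::Int32(), /*copy=*/false), so a tensor that already has
// the target dtype is returned as-is rather than copied.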
DATATYPE_FUNC(PyTensorObject_int, DType::Int32());
DATATYPE_FUNC(PyTensorObject_long, DType::Int64());
DATATYPE_FUNC(PyTensorObject_half, DType::Float16());
DATATYPE_FUNC(PyTensorObject_float, DType::Float());
DATATYPE_FUNC(PyTensorObject_double, DType::Double());
static PyObject* PyTensorObject_view(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* shape = args;
  if (PyTuple_Size(args) == 1) {
    PyObject* item = PyTuple_GetItem(args, 0);
    if (!PyLong_Check(item)) { shape = item; }
  }
  PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, shape));
  PyObject* result = functional::view(NULL, _args.get(), kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_view_as(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  PyObject* other = NULL;
  static const char* keywords[2] = {"other", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|:view_as", const_cast<char**>(keywords),
                                   &other)) {
    return NULL;
  }
  return PyTensor_New(ASSERT_PTR(functional::View(tensor, *PyTensor_Unpack(other)->shape())));
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_permute(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dims = args;
  if (PyTuple_Size(args) == 1) {
    PyObject* item = PyTuple_GetItem(args, 0);
    if (!PyLong_Check(item)) { dims = item; }
  }
  PyObjectPtr _args = PyObjectPtr(PyTuple_Pack(2, self, dims));
  PyObject* result = functional::permute(NULL, _args.get(), kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  int dim0 = 0;
  int dim1 = 0;
  static const char* keywords[3] = {"dim0", "dim1", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii:transpose", const_cast<char**>(keywords),
                                   &dim0, &dim1)) {
    return NULL;
  }
  return PyTensor_New(ASSERT_PTR(functional::Transpose2dim(tensor, dim0, dim1)));
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << "input must be a local tensor";
  PyObject* placement_obj = Py_None;
  PyObject* sbp_obj = Py_None;
  bool check_meta = true;
  static const char* keywords[4] = {"placement", "sbp", "check_meta", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!:local_to_global",
                                   const_cast<char**>(keywords), &placement_obj, &sbp_obj,
                                   &PyBool_Type, &check_meta)) {
    return NULL;
  };
  CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None) << Error::InvalidValueError(
      "Converting a local tensor to global tensor must have placement and sbp parameters.");
  CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj))
      << Error::TypeError() << "Invalid parameter placement with type "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));
  std::vector<Symbol<SbpParallel>> sbp;
  if (functional::PySbpParallelCheck(sbp_obj)) {
    sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));
  } else {
    CHECK_OR_THROW(functional::PySbpParallelSequenceCheck(sbp_obj))
        << Error::TypeError() << "Invalid parameter sbp with type "
        << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj)));
    sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
  }
  return PyTensor_New(ASSERT_PTR(functional::ToConsistent(
      tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta)));
  END_HANDLE_ERRORS
}
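
// Illustrative Python call (the exact placement constructor is the oneflow Python API,
// shown here only as a sketch):
//   t.local_to_global(placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(0))
// The `$` in the format string makes check_meta keyword-only; placement and sbp are
// optional to the parser but enforced as required by the CHECK_OR_THROW just after it.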
static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args,
                                                 PyObject* kwargs) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  CHECK_OR_THROW(tensor->is_consistent())
      << Error::RuntimeError() << "input must be a global tensor";
  PyObject* placement_obj = Py_None;
  PyObject* sbp_obj = Py_None;
  PyObject* grad_sbp_obj = Py_None;
  Symbol<ParallelDesc> placement;
  std::vector<Symbol<SbpParallel>> sbp;
  std::vector<Symbol<SbpParallel>> grad_sbp;
  bool check_meta = false;
  static const char* keywords[5] = {"placement", "sbp", "grad_sbp", "check_meta", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!:global_to_global",
                                   const_cast<char**>(keywords), &placement_obj, &sbp_obj,
                                   &grad_sbp_obj, &PyBool_Type, &check_meta)) {
    return NULL;
  };
  // sbp
  CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj)
                 || functional::PySbpParallelSequenceCheck(sbp_obj))
      << Error::TypeError()
      << "sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp";
  if (functional::PySbpParallelCheck(sbp_obj)) {
    sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));
  } else if (functional::PySbpParallelSequenceCheck(sbp_obj)) {
    sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
  } else {
    for (int32_t i = 0; i < ASSERT(tensor->nd_sbp())->sbp_parallel_size(); i++)
      sbp.emplace_back(ASSERT(tensor->nd_sbp())->sbp_parallel(i));
  }
  // placement
  CHECK_OR_THROW(placement_obj == Py_None || functional::PyParallelDescCheck(placement_obj))
      << Error::TypeError() << "Invalid parameter placement with type "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));
  if (placement_obj == Py_None) {
    placement = ASSERT(tensor->parallel_desc());
  } else {
    placement = functional::PyUnpackParallelDesc(placement_obj);
  }
  // grad_sbp
  CHECK_OR_THROW(grad_sbp_obj == Py_None || functional::PySbpParallelCheck(grad_sbp_obj)
                 || functional::PySbpParallelSequenceCheck(grad_sbp_obj))
      << Error::TypeError()
      << "grad_sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp";
  if (functional::PySbpParallelCheck(grad_sbp_obj)) {
    grad_sbp.emplace_back(functional::PyUnpackSbpParallel(grad_sbp_obj));
  } else if (functional::PySbpParallelSequenceCheck(grad_sbp_obj)) {
    grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj);
  }
  return PyTensor_New(
      ASSERT_PTR(functional::ToConsistent(tensor, placement, sbp, grad_sbp, check_meta)));
  END_HANDLE_ERRORS
}
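
// Unlike local_to_global, every parameter here may be omitted: sbp falls back to the
// tensor's current nd_sbp, placement to its current parallel_desc, and check_meta defaults
// to false since an already-global tensor's metadata is known. Illustrative Python:
//   t.to_global(sbp=flow.sbp.broadcast)   # keep placement, change sbp only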
static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  const auto& tensor = PyTensor_Unpack(self);
  PyObject* result = NULL;
  if (tensor->is_consistent()) {
    result = PyTensorObject_global_to_global(self, args, kwargs);
  } else {
    result = PyTensorObject_local_to_global(self, args, kwargs);
  }
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  CHECK_OR_THROW(tensor->is_consistent())
      << Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!";
  return PyTensor_New(ASSERT_PTR(functional::ConsistentToLocal(tensor)));
  END_HANDLE_ERRORS
}
int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
  HANDLE_ERRORS
  auto tensor = PyTensor_Unpack(self);
  std::shared_ptr<Tensor> value_tensor;
  CHECK_OR_THROW(functional::PyTensorIndexCheck(item))
      << Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));
  CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value))
      << Error::TypeError() << "tensor_setitem(): argument 'value' must be tensor or scalar, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));
  if (tensor->is_consistent()) {
    Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
    auto ndsbp = ASSERT(tensor->nd_sbp());
    std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
                                         ASSERT(MakeBroadcastSbpParallel()));
    if (functional::PyScalarCheck(value)) {
      Scalar value_scalar = functional::PyUnpackScalar(value);
      value_tensor = ASSERT_PTR(
          functional::ConsistentConstant({1}, value_scalar, tensor->dtype(), placement, sbp));
    } else {
      value_tensor = PyTensor_Unpack(value);
      CHECK_OR_THROW(value_tensor->is_consistent())
          << Error::RuntimeError()
          << "tensor_setitem(): value must be a global tensor when self is global";
      value_tensor = ASSERT_PTR(functional::ToConsistent(value_tensor, placement, sbp, {}, true));
    }
  } else {
    if (functional::PyScalarCheck(value)) {
      Scalar value_scalar = functional::PyUnpackScalar(value);
      value_tensor = ASSERT_PTR(
          functional::Constant({1}, value_scalar, tensor->dtype(), ASSERT(tensor->device())));
    } else {
      value_tensor = PyTensor_Unpack(value);
      CHECK_OR_THROW(value_tensor->is_local())
          << Error::RuntimeError()
          << "tensor_setitem(): value must be a local tensor when self is local";
      Optional<Symbol<Device>> device = ASSERT(tensor->device());
      value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
    }
  }
  ASSERT(functional::TensorSetItem(tensor, functional::PyUnpackTensorIndex(item), value_tensor));
  return 0;
  END_HANDLE_ERRORS_RET(-1)
}
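
// Scalar assignment is implemented by materializing the scalar as a one-element constant
// tensor (global or local to match `self`) and letting TensorSetItem broadcast it over the
// indexed region. Illustrative Python:
//   t[0:2, :] = 1.0     # scalar path: Constant({1}, 1.0, ...) then broadcast
//   t[0:2, :] = other   # tensor path: device/placement is aligned with `t` first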
PyMethodDef PyTensorObject_extra_methods[] = {
    {"byte", PyTensorObject_byte, METH_NOARGS, NULL},
    {"size", (PyCFunction)PyTensorObject_size, METH_VARARGS | METH_KEYWORDS, NULL},
    {"argmax", (PyCFunction)PyTensorObject_argmax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"argmin", (PyCFunction)PyTensorObject_argmin, METH_VARARGS | METH_KEYWORDS, NULL},
    {"amin", (PyCFunction)PyTensorObject_amin, METH_VARARGS | METH_KEYWORDS, NULL},
    {"dim", PyTensorObject_dim, METH_NOARGS, NULL},
    {"ndimension", PyTensorObject_dim, METH_NOARGS, NULL},
    {"nelement", PyTensorObject_nelement, METH_NOARGS, NULL},
    {"numel", PyTensorObject_nelement, METH_NOARGS, NULL},
    {"element_size", PyTensorObject_element_size, METH_NOARGS, NULL},
    {"get_device", PyTensorObject_get_device, METH_NOARGS, NULL},
    {"cast", (PyCFunction)PyTensorObject_cast, METH_VARARGS | METH_KEYWORDS, NULL},
    {"diag", (PyCFunction)PyTensorObject_diag, METH_VARARGS | METH_KEYWORDS, NULL},
    {"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL},
    {"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"int", PyTensorObject_int, METH_NOARGS, NULL},
    {"long", PyTensorObject_long, METH_NOARGS, NULL},
    {"half", PyTensorObject_half, METH_NOARGS, NULL},
    {"float", PyTensorObject_float, METH_NOARGS, NULL},
    {"double", PyTensorObject_double, METH_NOARGS, NULL},
    {"local_to_global", (PyCFunction)PyTensorObject_local_to_global,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"global_to_global", (PyCFunction)PyTensorObject_global_to_global,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"to_local", PyTensorObject_to_local, METH_NOARGS, NULL},
    {"to_global", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL},
    {"cpu", PyTensorObject_cpu, METH_NOARGS, NULL},
    {"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
    {"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL},
    {"std", (PyCFunction)PyTensorObject_std, METH_VARARGS | METH_KEYWORDS, NULL},
    {"softplus", (PyCFunction)PyTensorObject_softplus, METH_VARARGS | METH_KEYWORDS, NULL},
    {"relu", PyTensorObject_relu, METH_NOARGS, NULL},
    {"relu_", PyTensorObject_relu_, METH_NOARGS, NULL},
    {"all", (PyCFunction)PyTensorObject_all, METH_VARARGS | METH_KEYWORDS, NULL},
    {"any", (PyCFunction)PyTensorObject_any, METH_VARARGS | METH_KEYWORDS, NULL},
    {"sum", (PyCFunction)PyTensorObject_sum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mean", (PyCFunction)PyTensorObject_mean, METH_VARARGS | METH_KEYWORDS, NULL},

    // macro DIRECT_PASS_FUNC
    {"floor_divide", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL},
    {"atan2", (PyCFunction)PyTensorObject_atan2, METH_VARARGS | METH_KEYWORDS, NULL},
    {"gt", (PyCFunction)PyTensorObject_gt, METH_VARARGS | METH_KEYWORDS, NULL},
    {"ge", (PyCFunction)PyTensorObject_ge, METH_VARARGS | METH_KEYWORDS, NULL},
    {"div", (PyCFunction)PyTensorObject_div, METH_VARARGS | METH_KEYWORDS, NULL},
    {"div_", (PyCFunction)PyTensorObject_div_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mul", (PyCFunction)PyTensorObject_mul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mul_", (PyCFunction)PyTensorObject_mul_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"fmod", (PyCFunction)PyTensorObject_fmod, METH_VARARGS | METH_KEYWORDS, NULL},
    {"logical_and", (PyCFunction)PyTensorObject_logical_and, METH_VARARGS | METH_KEYWORDS, NULL},
    {"logical_or", (PyCFunction)PyTensorObject_logical_or, METH_VARARGS | METH_KEYWORDS, NULL},
    {"logical_xor", (PyCFunction)PyTensorObject_logical_xor, METH_VARARGS | METH_KEYWORDS, NULL},
    {"bmm", (PyCFunction)PyTensorObject_bmm, METH_VARARGS | METH_KEYWORDS, NULL},
    {"ne", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL},
    {"lt", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL},
    {"le", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clip", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clip_", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clamp", (PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clamp_", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"flatten", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL},
    {"in_top_k", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL},
    {"index_select", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL},
    {"maximum", (PyCFunction)PyTensorObject_maximum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"minimum", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"tril", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL},
    {"triu", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL},
    {"softmax", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"log_softmax", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"roll", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unbind", (PyCFunction)PyTensorObject_unbind, METH_VARARGS | METH_KEYWORDS, NULL},
    {"squeeze", (PyCFunction)PyTensorObject_squeeze, METH_VARARGS | METH_KEYWORDS, NULL},
    {"swapaxes", (PyCFunction)PyTensorObject_swapaxes, METH_VARARGS | METH_KEYWORDS, NULL},
    {"amax", (PyCFunction)PyTensorObject_amax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"swapdims", (PyCFunction)PyTensorObject_swapdims, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unfold", (PyCFunction)PyTensorObject_unfold, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unsqueeze", (PyCFunction)PyTensorObject_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL},
    {"max", (PyCFunction)PyTensorObject_max, METH_VARARGS | METH_KEYWORDS, NULL},
    {"min", (PyCFunction)PyTensorObject_min, METH_VARARGS | METH_KEYWORDS, NULL},
    {"median", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL},
    {"pow", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL},
    {"chunk", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL},
    {"narrow", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL},
    {"masked_fill", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL},
    {"dot", (PyCFunction)PyTensorObject_dot, METH_VARARGS | METH_KEYWORDS, NULL},

    // macro UNARY_METHOD
    {"abs", PyTensorObject_abs, METH_NOARGS, NULL},
    {"exp", PyTensorObject_exp, METH_NOARGS, NULL},
    {"floor", PyTensorObject_floor, METH_NOARGS, NULL},
    {"floor_", PyTensorObject_floor_, METH_NOARGS, NULL},
    {"acos", PyTensorObject_acos, METH_NOARGS, NULL},
    {"arccos", PyTensorObject_acos, METH_NOARGS, NULL},
    {"acosh", PyTensorObject_acosh, METH_NOARGS, NULL},
    {"arccosh", PyTensorObject_acosh, METH_NOARGS, NULL},
    {"atanh", PyTensorObject_atanh, METH_NOARGS, NULL},
    {"arctanh", PyTensorObject_atanh, METH_NOARGS, NULL},
    {"sign", PyTensorObject_sign, METH_NOARGS, NULL},
    {"sinh", PyTensorObject_sinh, METH_NOARGS, NULL},
    {"tan", PyTensorObject_tan, METH_NOARGS, NULL},
    {"gelu", PyTensorObject_gelu, METH_NOARGS, NULL},
    {"mish", PyTensorObject_mish, METH_NOARGS, NULL},
    {"negative", PyTensorObject_negative, METH_NOARGS, NULL},
    {"neg", PyTensorObject_negative, METH_NOARGS, NULL},
    {"sigmoid", PyTensorObject_sigmoid, METH_NOARGS, NULL},
    {"tanh", PyTensorObject_tanh, METH_NOARGS, NULL},
    {"silu", PyTensorObject_silu, METH_NOARGS, NULL},
    {"selu", PyTensorObject_selu, METH_NOARGS, NULL},
    {"softsign", PyTensorObject_softsign, METH_NOARGS, NULL},
    {"log1p", PyTensorObject_log1p, METH_NOARGS, NULL},
    {"log2", PyTensorObject_log2, METH_NOARGS, NULL},
    {"reciprocal", PyTensorObject_reciprocal, METH_NOARGS, NULL},
    {"asin", PyTensorObject_asin, METH_NOARGS, NULL},
    {"arcsin", PyTensorObject_asin, METH_NOARGS, NULL},
    {"asinh", PyTensorObject_asinh, METH_NOARGS, NULL},
    {"arcsinh", PyTensorObject_asinh, METH_NOARGS, NULL},
    {"atan", PyTensorObject_atan, METH_NOARGS, NULL},
    {"arctan", PyTensorObject_atan, METH_NOARGS, NULL},
    {"ceil", PyTensorObject_ceil, METH_NOARGS, NULL},
    {"cos", PyTensorObject_cos, METH_NOARGS, NULL},
    {"cosh", PyTensorObject_cosh, METH_NOARGS, NULL},
    {"erf", PyTensorObject_erf, METH_NOARGS, NULL},
    {"erfc", PyTensorObject_erfc, METH_NOARGS, NULL},
    {"erfinv", PyTensorObject_erfinv, METH_NOARGS, NULL},
    {"erfinv_", PyTensorObject_erfinv_, METH_NOARGS, NULL},
    {"expm1", PyTensorObject_expm1, METH_NOARGS, NULL},
    {"log", PyTensorObject_log, METH_NOARGS, NULL},
    {"rsqrt", PyTensorObject_rsqrt, METH_NOARGS, NULL},
    {"sqrt", PyTensorObject_sqrt, METH_NOARGS, NULL},
    {"square", PyTensorObject_square, METH_NOARGS, NULL},
    {"round", PyTensorObject_round, METH_NOARGS, NULL},
    {"t", PyTensorObject_t, METH_NOARGS, NULL},
    {"sin", PyTensorObject_sin, METH_NOARGS, NULL},
    {"sin_", PyTensorObject_sin_, METH_NOARGS, NULL},
    {"isnan", PyTensorObject_isnan, METH_NOARGS, NULL},
    {"isinf", PyTensorObject_isinf, METH_NOARGS, NULL},
    {"logical_not", PyTensorObject_logical_not, METH_NOARGS, NULL},
    {"reshape", (PyCFunction)PyTensorObject_reshape, METH_VARARGS | METH_KEYWORDS, NULL},
    {"reshape_as", (PyCFunction)PyTensorObject_reshape_as, METH_VARARGS | METH_KEYWORDS, NULL},
    {"view", (PyCFunction)PyTensorObject_view, METH_VARARGS | METH_KEYWORDS, NULL},
    {"view_as", (PyCFunction)PyTensorObject_view_as, METH_VARARGS | METH_KEYWORDS, NULL},
    {"permute", (PyCFunction)PyTensorObject_permute, METH_VARARGS | METH_KEYWORDS, NULL},
    {"transpose", (PyCFunction)PyTensorObject_transpose, METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL},
};
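
// The rich-comparison slot below maps each Python comparison operator to the matching
// generated functional binding, e.g. `a < b` dispatches to functional::less(NULL, (a, b),
// NULL). All six CPython comparison opcodes are covered, so the trailing return NULL is
// unreachable in practice.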
// tp_richcompare
PyObject* PyTensorObject_richcompare(PyObject* self, PyObject* other, int op) {
  PyObjectPtr tuple(PyTuple_Pack(2, self, other));
  switch (op) {
    case Py_LT: return functional::less(NULL, tuple.get(), NULL);
    case Py_LE: return functional::less_equal(NULL, tuple.get(), NULL);
    case Py_EQ: {
      // Comparing against None is always unequal; Py_RETURN_FALSE increments the
      // refcount of the False singleton before returning it.
      if (self == Py_None || other == Py_None) Py_RETURN_FALSE;
      return functional::equal(NULL, tuple.get(), NULL);
    }
    case Py_NE: return functional::not_equal(NULL, tuple.get(), NULL);
    case Py_GT: return functional::greater(NULL, tuple.get(), NULL);
    case Py_GE: return functional::greater_equal(NULL, tuple.get(), NULL);
  }
  return NULL;
}
}  // namespace one
}  // namespace oneflow

#undef ASSERT
#undef ASSERT_PTR