OpenDAS / Oneflow · Commits

Commit 21d47d0e, authored Oct 24, 2022 by yuguo

Oneflow 0.8 for DCU
Changes: 556 files. Showing 20 changed files with 3226 additions and 0 deletions (+3226, -0).
oneflow/api/python/framework/tensor_tuple.cpp              +84   -0
oneflow/api/python/framework/tensortype.cpp                +185  -0
oneflow/api/python/framework/tensortype.h                  +49   -0
oneflow/api/python/framework/variable_tensor_mgr.cpp       +32   -0
oneflow/api/python/functional/common.cpp                   +304  -0
oneflow/api/python/functional/common.h                     +190  -0
oneflow/api/python/functional/dispatch_stateful_ops.cpp    +542  -0
oneflow/api/python/functional/dispatch_stateful_ops.yaml   +157  -0
oneflow/api/python/functional/function_def.h               +77   -0
oneflow/api/python/functional/indexing.cpp                 +222  -0
oneflow/api/python/functional/indexing.h                   +44   -0
oneflow/api/python/functional/python_arg.cpp               +247  -0
oneflow/api/python/functional/python_arg.h                 +123  -0
oneflow/api/python/functional/python_arg_parser.cpp        +124  -0
oneflow/api/python/functional/python_arg_parser.h          +108  -0
oneflow/api/python/functional/python_frame.h               +90   -0
oneflow/api/python/functional/tensor_api.cpp               +322  -0
oneflow/api/python/functional/tensor_api.yaml              +43   -0
oneflow/api/python/functional/value_types.cpp              +92   -0
oneflow/api/python/functional/value_types.h                +191  -0
Too many changes to show: to preserve performance, only 556 of 556+ files are displayed.
oneflow/api/python/framework/tensor_tuple.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor.h"
namespace py = pybind11;

namespace oneflow {
namespace one {

namespace {

struct TensorTupleUtil final {
  static std::string ToString(const TensorTuple& tensor_tuple) {
    std::stringstream ss;
    int32_t idx = 0;
    ss << "TensorTuple(";
    for (const std::shared_ptr<Tensor>& tensor : tensor_tuple) {
      ss << tensor;
      if (++idx != tensor_tuple.size() || tensor_tuple.size() == 1) { ss << ", "; }
    }
    ss << ")";
    return ss.str();
  }

  static void MergeFrom(std::shared_ptr<TensorTuple>& tensor_tuple, const TensorTuple& other) {
    for (const auto& tensor : other) { tensor_tuple->emplace_back(tensor); }
  }

  static void AppendTensor(std::shared_ptr<TensorTuple>& tensor_tuple,
                           const std::shared_ptr<Tensor>& tensor) {
    tensor_tuple->emplace_back(tensor);
  }
};

}  // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<TensorTuple, std::shared_ptr<TensorTuple>>(m, "TensorTuple")
      .def(py::init([]() { return std::make_shared<TensorTuple>(); }))
      .def(py::init([](const std::shared_ptr<TensorTuple>& other) { return other; }))
      .def(py::init([](const std::vector<std::shared_ptr<Tensor>>& list) {
        auto tensor_tuple = std::make_shared<TensorTuple>();
        for (const auto& t : list) { tensor_tuple->emplace_back(t); }
        return tensor_tuple;
      }))
      .def("__str__", &TensorTupleUtil::ToString)
      .def("__repr__", &TensorTupleUtil::ToString)
      .def("__getitem__",
           [](const TensorTuple& tensor_tuple, int idx) { return tensor_tuple.at(idx); })
      .def("__setitem__",
           [](std::shared_ptr<TensorTuple>& tensor_tuple, int idx,
              const std::shared_ptr<Tensor>& tensor) { tensor_tuple->at(idx) = tensor; })
      .def(
          "__iter__",
          [](const TensorTuple& tensor_tuple) {
            return py::make_iterator(tensor_tuple.begin(), tensor_tuple.end());
          },
          py::keep_alive<0, 1>())
      .def("__len__", [](const TensorTuple& tensor_tuple) { return tensor_tuple.size(); })
      .def("merge_from", &TensorTupleUtil::MergeFrom)
      .def("append", &TensorTupleUtil::AppendTensor);
}

}  // namespace one
}  // namespace oneflow
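The binding above gives TensorTuple Python list semantics: construction from a list of tensors, indexing, iteration, and append. A rough Python sketch of the resulting behavior; treat it as illustrative only, since the exposed name flow.TensorTuple on the top-level module is inferred from ONEFLOW_API_PYBIND11_MODULE("", m) and not confirmed by this diff:

import oneflow as flow

t0 = flow.ones(2, 2)
t1 = flow.zeros(3)

tup = flow.TensorTuple([t0, t1])  # init from a list of tensors
print(len(tup))                   # __len__ -> 2
print(tup[0])                     # __getitem__ -> t0
tup.append(t1)                    # AppendTensor
for t in tup:                     # __iter__; py::keep_alive<0, 1> keeps the
    print(t.shape)                #   tuple alive while the iterator exists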
oneflow/api/python/framework/tensortype.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <Python.h>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/framework/tensor.h"
#include "oneflow/api/python/framework/tensortype.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/tensor_api.yaml.pybind.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/api/python/exception/exception.h"
namespace oneflow {
namespace one {

#define ASSERT(x) (x).GetOrThrow()
#define ASSERT_PTR(x) (x).GetPtrOrThrow()

static PyTypeObject PyTensorTypeMetaClass{
    PyVarObject_HEAD_INIT(NULL, 0) "oneflow.tensortype",  // tp_name
    sizeof(PyTypeObject),                                 // tp_basicsize
};

static PyTypeObject PyTensorTypeTemplate{
    PyVarObject_HEAD_INIT(&PyTensorTypeMetaClass, 0) NULL,  // tp_name
    sizeof(PyTensorType),                                   // tp_basicsize
};

static std::vector<PyTensorType*> tensor_types;

static std::vector<std::pair<const Symbol<DType>&, std::string>> all_data_types = {
    {DType::Float(), "FloatTensor"},   {DType::Double(), "DoubleTensor"},
    {DType::Int8(), "CharTensor"},     {DType::Int32(), "IntTensor"},
    {DType::Int64(), "LongTensor"},    {DType::UInt8(), "ByteTensor"},
    {DType::Float16(), "HalfTensor"},  {DType::BFloat16(), "BFloat16Tensor"},
    {DType::Bool(), "BoolTensor"},
};

static std::vector<std::pair<DeviceType, std::string>> all_device_types = {
    {kCPU, "oneflow"},
    {kCUDA, "oneflow.cuda"},
};

static PyObject* PyTensorTypeMetaCls_call(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  auto* tensor = functional::_legacy_tensor_ctor(NULL, args, kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  if (!TRY(DeviceTag4DeviceType(PyTensorType_UnpackDevice(self))).IsOk())
    return PyErr_Format(PyExc_ValueError, "invalid device");
  Optional<std::string> device = ASSERT(DeviceTag4DeviceType(PyTensorType_UnpackDevice(self)));
  const auto& data_type = PyTensorType_UnpackDType(self);
  return PyTensor_New(
      ASSERT_PTR(functional::To(PyTensor_Unpack(tensor), device, data_type, /*copy=*/false)));
  END_HANDLE_ERRORS
};

PyObject* PyTensorType_FromString(const std::string& tensortype) {
  auto it = std::find_if(tensor_types.begin(), tensor_types.end(),
                         [tensortype](PyTensorType* type) {
                           return std::string(type->name) == tensortype;
                         });
  if (it == tensor_types.end()) {
    PyErr_Format(PyExc_ValueError, "invalid type: %s", tensortype.data());
    throw py::error_already_set();
  }
  return (PyObject*)(*it);
}

static const char* get_doc(PyTensorType* tensortype) {
  // all tensortype docs
  static std::vector<std::string> tensortype_doc;
  std::string dtype = tensortype->dtype->name();
  std::string doc = "";
  if (!TRY(DeviceTag4DeviceType(tensortype->devicetype)).IsOk())
    doc = "The tensortype " + std::string(tensortype->name) + " is not available.";
  else {
    std::string device = ASSERT(DeviceTag4DeviceType(tensortype->devicetype));
    doc = "Creates a Tensor with the dtype of " + dtype + " and the device on " + device
          + ", it has the same parameters as :func:`oneflow.Tensor`";
  }
  tensortype_doc.emplace_back(doc);
  return tensortype_doc.back().data();
}

static void init_tensortype_metaclass(PyTypeObject* metaclass) {
  metaclass->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  metaclass->tp_base = &PyType_Type;
  metaclass->tp_call = PyTensorTypeMetaCls_call;
  if (PyType_Ready(metaclass) < 0) { return; }
}

static void init_tensortype(PyTypeObject* type, PyTypeObject& type_template, const char* name,
                            const char* doc) {
  memcpy(type, &type_template, sizeof(PyTypeObject));
  type->tp_name = name;
  type->tp_doc = doc;
  type->tp_flags = Py_TPFLAGS_DEFAULT;
  if (PyType_Ready(type) < 0) { THROW(RuntimeError) << "tensortype initialization failed"; }
}

static void generalize_tensor_types() {
  init_tensortype_metaclass(&PyTensorTypeMetaClass);
  for (const auto& devicetype : all_device_types) {
    for (const auto& dtype : all_data_types) {
      PyTensorType* tensortype = new PyTensorType();
      // set name
      std::string name = devicetype.second + "." + dtype.second;
      size_t n = sizeof(tensortype->name);
      strncpy(tensortype->name, name.c_str(), n - 1);
      tensortype->name[n - 1] = '\0';
      // set type
      tensortype->dtype = dtype.first;
      tensortype->devicetype = devicetype.first;
      tensortype->is_cuda = tensortype->devicetype == DeviceType::kCUDA;
      tensor_types.push_back(tensortype);
      const char* doc = get_doc(tensortype);
      init_tensortype(&tensortype->py_type, PyTensorTypeTemplate, tensortype->name, doc);
    }
  }
}

bool PyTensorType_Check(PyObject* obj) {
  return PyObject_TypeCheck(obj, &PyTensorTypeMetaClass);
}

PyObject* PyTensorType_FromDTypeAndDeviceType(Symbol<DType> dtype, DeviceType device) {
  auto it = std::find_if(tensor_types.begin(), tensor_types.end(),
                         [dtype, device](PyTensorType* x) {
                           return (x->dtype == dtype) && (x->devicetype == device);
                         });
  if (it == tensor_types.end()) {
    if (!TRY(DeviceTag4DeviceType(device)).IsOk())
      return PyErr_Format(PyExc_ValueError, "unsupported device");
    return PyErr_Format(PyExc_ValueError, "unsupported data type (%s) or device (%s)",
                        dtype->name().c_str(), ASSERT(DeviceTag4DeviceType(device)).c_str());
  }
  return (PyObject*)(*it);
};

}  // namespace one
}  // namespace oneflow

#undef ASSERT

using namespace oneflow::one;

ONEFLOW_API_PYBIND11_MODULE("_C", m) {
  static std::string oneflow_prefix = "oneflow.";
  generalize_tensor_types();
  for (PyTensorType* tensortype : tensor_types) {
    Py_INCREF(tensortype);
    std::string name = std::string(tensortype->name);
    size_t idx = name.rfind('.');
    std::string type_name = name.substr(idx + 1);
    name = name.substr(0, idx);
    std::string module_name =
        name.size() > oneflow_prefix.size() ? name.substr(oneflow_prefix.size()) : "";
    auto module = m;
    if (!module_name.empty()) { module = m.def_submodule(module_name.data()); }
    if (tensortype
        && PyModule_AddObject(module.ptr(), type_name.c_str(), (PyObject*)tensortype) < 0) {
      return;
    }
  }
}
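generalize_tensor_types() creates one PyTensorType per (device, dtype) pair from the two tables above, and the "_C" module hook re-exports each under oneflow or oneflow.cuda. A minimal sketch of the intended Python-side usage, assuming a build where these types are actually installed; calling a type goes through PyTensorTypeMetaCls_call, i.e. the legacy tensor constructor followed by functional::To:

import oneflow as flow

x = flow.FloatTensor(2, 3)           # CPU float32, per {kCPU, "oneflow"} x {Float, "FloatTensor"}
y = flow.cuda.LongTensor([1, 2, 3])  # CUDA int64; raises ValueError("invalid device") without CUDA
print(x.dtype, x.device)
print(flow.FloatTensor.__doc__)      # doc text produced by get_doc()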
oneflow/api/python/framework/tensortype.h (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_
#include <Python.h>
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/device.h"
namespace oneflow {
namespace one {

typedef struct {
  PyTypeObject py_type;
  char name[64];
  bool is_cuda;
  Symbol<DType> dtype;
  DeviceType devicetype;
} PyTensorType;

bool PyTensorType_Check(PyObject*);

inline DeviceType PyTensorType_UnpackDevice(PyObject* self) {
  return ((PyTensorType*)self)->devicetype;
}

inline Symbol<DType> PyTensorType_UnpackDType(PyObject* self) {
  return ((PyTensorType*)self)->dtype;
}

PyObject* PyTensorType_FromDTypeAndDeviceType(Symbol<DType>, DeviceType);
PyObject* PyTensorType_FromString(const std::string&);

}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_TENSORTYPE_H_
oneflow/api/python/framework/variable_tensor_mgr.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tuple>
#include "oneflow/api/common/variable_tensor_mgr.h"
#include "oneflow/api/python/of_api_registry.h"
namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("FillVariableTensorMgr", &FillVariableTensorMgr);
  m.def("DumpVariableTensorMgr", &DumpVariableTensorMgr);
  m.def("ClearVariableTensorMgr", &ClearVariableTensorMgr);
}

}  // namespace oneflow
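These bindings are thin wrappers over the functions declared in oneflow/api/common/variable_tensor_mgr.h. A hedged sketch of how they might be driven from Python; the oneflow._oneflow_internal module path and the (names, tensors) signature are assumptions inferred from the registry and the C++ declarations, and these are internal APIs normally invoked during graph compilation rather than by users:

import oneflow as flow
import oneflow._oneflow_internal as internal   # assumed module path

names = ["model.weight"]                        # hypothetical variable op names
tensors = [flow.ones(2, 2)]
internal.FillVariableTensorMgr(names, tensors)  # assumed signature
names, tensors = internal.DumpVariableTensorMgr()
internal.ClearVariableTensorMgr()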
oneflow/api/python/functional/common.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/common.h"
#include <object.h>
#include <string>
#include "oneflow/api/python/functional/indexing.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/tensor_index.h"
namespace oneflow {
namespace one {
namespace functional {

bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check) {
  bool is_tuple = PyTuple_Check(obj);
  if (!is_tuple && !PyList_Check(obj)) { return false; }
  size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (size == 0) { return true; }
  PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
  return item_check(item);
}

bool PyLongSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PyLong_Check(item); });
}

bool PyFloatSquenceCheck(PyObject* obj) {
  return PySequenceCheck(
      obj, [](PyObject* item) { return PyFloat_Check(item) || PyLong_Check(item); });
}

bool PyStringCheck(PyObject* obj) { return PyBytes_Check(obj) || PyUnicode_Check(obj); }

bool PyStringSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PyStringCheck(item); });
}

std::string PyStringAsString(PyObject* obj) {
  PyObject* bytes = PyUnicode_AsEncodedString(obj, "utf-8", "~E~");
  std::string str = PyBytes_AS_STRING(bytes);
  Py_XDECREF(bytes);
  return str;
}

std::string PyObjectToReprStr(PyObject* obj) {
  PyObject* repr_obj = PyObject_Repr(obj);
  std::string str = PyStringAsString(repr_obj);
  Py_XDECREF(repr_obj);
  return str;
}

// Tensor list
bool PyTensorSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PyTensor_Check(item); });
}

std::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj) {
  return PyUnpackSequence<std::shared_ptr<Tensor>>(
      obj, [](PyObject* item) { return PyTensor_Unpack(item); });
}

// TensorTuple
bool PyTensorTupleCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<TensorTuple>(handle);
}

std::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::cast<std::shared_ptr<TensorTuple>>(handle);
}

// Scalar
bool PyScalarCheck(PyObject* obj) { return PyLong_Check(obj) || PyFloat_Check(obj); }

Scalar PyUnpackScalar(PyObject* obj) {
  if (PyBool_Check(obj)) {
    return obj == Py_True;
  } else if (PyLong_Check(obj)) {
    return static_cast<int64_t>(PyLong_AsLongLong(obj));
  } else if (PyFloat_Check(obj)) {
    return PyFloat_AsDouble(obj);
  }
  THROW(RuntimeError) << "The object is not scalar, but is " << Py_TYPE(obj)->tp_name;
  return 0;
}

// DType
bool PyDTypeCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<Symbol<DType>>(handle);
}

Symbol<DType> PyUnpackDType(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return *py::cast<Symbol<DType>*>(handle);
}

// DType list
bool PyDTypeSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); });
}

std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj) {
  return PyUnpackSequence<Symbol<DType>>(obj, [](PyObject* item) { return PyUnpackDType(item); });
}

// Shape list
bool PyShapeSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); });
}

std::vector<Shape> PyUnpackShapeSequence(PyObject* obj) {
  return PyUnpackSequence<Shape>(obj, [](PyObject* item) -> Shape {
    const auto& shape = PyUnpackLongSequence<int64_t>(item);
    return Shape(DimVector(shape.begin(), shape.end()));
  });
}

// Generator
bool PyGeneratorCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<Generator>(handle);
}

std::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::cast<std::shared_ptr<one::Generator>>(handle);
}

// Device
bool PyDeviceCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<Symbol<Device>>(handle);
}

Symbol<Device> PyUnpackDevice(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return *py::cast<std::shared_ptr<Symbol<Device>>>(handle);
}

// Placement
bool PyParallelDescCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<Symbol<ParallelDesc>>(handle);
}

Symbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return *py::cast<std::shared_ptr<Symbol<ParallelDesc>>>(handle);
}

// SBP
bool PySbpParallelCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<Symbol<SbpParallel>>(handle);
}

Symbol<SbpParallel> PyUnpackSbpParallel(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return *py::cast<std::shared_ptr<Symbol<SbpParallel>>>(handle);
}

// SBP list
bool PySbpParallelSequenceCheck(PyObject* obj) {
  return PySequenceCheck(obj, [](PyObject* item) { return PySbpParallelCheck(item); });
}

std::vector<Symbol<SbpParallel>> PyUnpackSbpParallelSequence(PyObject* obj) {
  return PyUnpackSequence<Symbol<SbpParallel>>(
      obj, [](PyObject* item) { return PyUnpackSbpParallel(item); });
}

// Tensor index
bool PyTensorIndexCheck(PyObject* obj) {
  return PySlice_Check(obj) || PyLong_Check(obj) || obj == Py_Ellipsis || obj == Py_None
         || PyTensor_Check(obj) || PySequence_Check(obj) || PyUnicode_Check(obj)
         || numpy::PyArrayCheckLongScalar(obj);
}

TensorIndex PyUnpackTensorIndex(PyObject* obj) {
  TensorIndex tensor_index;
  // Obvious single-entry cases.
  if (PySlice_Check(obj)                     // NOLINT
      || PyLong_Check(obj)                   // NOLINT
      || obj == Py_Ellipsis                  // NOLINT
      || obj == Py_None                      // NOLINT
      || PyTensor_Check(obj)                 // NOLINT
      || !PySequence_Check(obj)              // NOLINT
      || numpy::PyArrayCheckLongScalar(obj)  // NOLINT
      || PyUnicode_Check(obj)) {
    tensor_index.emplace_back(detail::UnpackIndexItem(obj));
    return tensor_index;
  }
  PyObject* tup = NULL;
  Py_ssize_t n = 0;
  if (PyTuple_Check(obj)) {
    tup = PySequence_Tuple(obj);
    n = PySequence_Size(tup);
  } else {
    // The follow comments are from numpy:
    // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L266
    /*
     * At this point, we're left with a non-tuple, non-array, sequence:
     * typically, a list. We use some somewhat-arbitrary heuristics from here
     * onwards to decided whether to treat that list as a single index, or a
     * list of indices.
     */
    n = PySequence_Size(obj);
    // Negative size indicates a Python error in the PySequence_Size call.
    if (n < 0) {
      PyErr_Clear();
      tensor_index.emplace_back(detail::UnpackIndexItem(obj));
      return tensor_index;
    }
    // The follow comments are from numpy:
    // https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/mapping.c#L280
    /*
     * Backwards compatibility only takes effect for short sequences - otherwise
     * we treat it like any other scalar.
     *
     * Sequences < NPY_MAXDIMS with any slice objects
     * or newaxis, Ellipsis or other arrays or sequences
     * embedded, are considered equivalent to an indexing
     * tuple. (`a[[[1,2], [3,4]]] == a[[1,2], [3,4]]`)
     */
    if (n >= /*NPY_MAXDIMS=*/32) {
      tensor_index.emplace_back(detail::UnpackIndexItem(obj));
      return tensor_index;
    }
    // Check whether we should unpack the index like a tuple.
    bool commit_to_unpack = false;
    for (Py_ssize_t i = 0; i < n; ++i) {
      PyObject* item = PySequence_GetItem(obj, i);
      if (commit_to_unpack) {
        CHECK_OR_THROW(item) << "Sequence index is required.";
      } else {
        if (!item) {
          PyErr_Clear();
          break;
        }
        if (PySequence_Check(item)   // NOLINT
            || PySlice_Check(item)   // NOLINT
            || PyTensor_Check(item)  // NOLINT
            || item == Py_Ellipsis || item == Py_None) {
          commit_to_unpack = true;
        }
      }
      Py_DECREF(item);
    }
    if (commit_to_unpack) {
      tup = PySequence_Tuple(obj);
    } else {
      tensor_index.emplace_back(detail::UnpackIndexItem(obj));
      return tensor_index;
    }
  }
  tensor_index.resize(n);
  for (Py_ssize_t i = 0; i < n; ++i) {
    PyObject* item = PySequence_GetItem(tup, i);
    tensor_index[i] = detail::UnpackIndexItem(item);
    Py_DECREF(item);
  }
  Py_DECREF(tup);
  return tensor_index;
}

// OpExpr
bool PyOpExprCheck(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::isinstance<OpExpr>(handle);
}

std::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::cast<std::shared_ptr<OpExpr>>(handle);
}

// int64_t
Maybe<int64_t> PyUnpackLong(PyObject* py_obj) {
  int overflow = -1;
  long long val = PyLong_AsLongLongAndOverflow(py_obj, &overflow);
  if (val == -1 && PyErr_Occurred()) { return Error::RuntimeError() << "Python exception occurs"; }
  if (overflow != 0) { return Error::RuntimeError() << "Overflow when unpacking long"; }
  return (int64_t)val;
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow
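PyUnpackTensorIndex reproduces numpy's legacy heuristic for deciding whether a short plain list used as an index should be treated as a tuple of indices. A pure-Python paraphrase of the commit_to_unpack logic above (illustrative only; not part of this diff, and it ignores the tensor/ndarray cases):

NPY_MAXDIMS = 32

def should_unpack_as_tuple(index):
    if isinstance(index, tuple):
        return True
    # Non-sequences and long sequences are a single index item.
    if not isinstance(index, list) or len(index) >= NPY_MAXDIMS:
        return False
    # A short list commits to tuple-unpacking as soon as it contains a
    # slice, Ellipsis, None, or a nested sequence.
    return any(isinstance(item, (list, tuple, slice)) or item is Ellipsis or item is None
               for item in index)

assert should_unpack_as_tuple([[1, 2], [3, 4]])  # a[[[1,2],[3,4]]] == a[[1,2],[3,4]]
assert not should_unpack_as_tuple([1, 2, 3])     # plain int list: one advanced index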
oneflow/api/python/functional/common.h (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/framework/tensor.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/tensor_index.h"
#include "oneflow/core/common/foreign_lock_helper.h"
namespace py = pybind11;

namespace oneflow {
namespace one {
namespace functional {

struct PyObjectPtrDeleter {
  inline void operator()(PyObject* obj) {
    CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {
      if (obj) { Py_DECREF(obj); }
      obj = NULL;
      return Maybe<void>::Ok();
    }));
  }
};

using PyObjectPtr = std::unique_ptr<PyObject, PyObjectPtrDeleter>;

#define INTEGER_TYPE_SEQ         \
  OF_PP_MAKE_TUPLE_SEQ(int32_t)  \
  OF_PP_MAKE_TUPLE_SEQ(uint32_t) \
  OF_PP_MAKE_TUPLE_SEQ(int64_t)  \
  OF_PP_MAKE_TUPLE_SEQ(uint64_t) \
  OF_PP_MAKE_TUPLE_SEQ(bool)

#define FLOATING_TYPE_SEQ       \
  OF_PP_MAKE_TUPLE_SEQ(float)   \
  OF_PP_MAKE_TUPLE_SEQ(double)

bool PySequenceCheck(PyObject* obj);
bool PySequenceCheck(PyObject* obj, const std::function<bool(PyObject*)>& item_check);

template<typename T, typename UnpackItemFunc>
inline std::vector<T> PyUnpackSequence(PyObject* obj, UnpackItemFunc unpack_item) {
  bool is_tuple = PyTuple_Check(obj);
  CHECK_OR_THROW(is_tuple || PyList_Check(obj))
      << "The object is not list or tuple, but is " << Py_TYPE(obj)->tp_name;
  size_t size = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  std::vector<T> values(size);
  for (int i = 0; i < size; ++i) {
    PyObject* item = is_tuple ? PyTuple_GET_ITEM(obj, i) : PyList_GET_ITEM(obj, i);
    values[i] = unpack_item(item);
  }
  return values;
}

// Integer/Float list
bool PyLongSequenceCheck(PyObject* obj);
bool PyFloatSquenceCheck(PyObject* obj);

template<typename T>
inline std::vector<T> PyUnpackLongSequence(PyObject* obj) {
  return PyUnpackSequence<T>(
      obj, [](PyObject* item) -> T { return static_cast<T>(PyLong_AsLongLong(item)); });
}

template<typename T>
inline std::vector<T> PyUnpackFloatSequence(PyObject* obj) {
  return PyUnpackSequence<T>(
      obj, [](PyObject* item) -> T { return static_cast<T>(PyFloat_AsDouble(item)); });
}

// String
bool PyStringCheck(PyObject* obj);
bool PyStringSequenceCheck(PyObject* obj);
std::string PyStringAsString(PyObject* obj);
std::string PyObjectToReprStr(PyObject* obj);

// Scalar
bool PyScalarCheck(PyObject* obj);
Scalar PyUnpackScalar(PyObject* obj);

// Tensor list
bool PyTensorSequenceCheck(PyObject* obj);
std::vector<std::shared_ptr<Tensor>> PyUnpackTensorSequence(PyObject* obj);

// TensorTuple
bool PyTensorTupleCheck(PyObject* obj);
std::shared_ptr<TensorTuple> PyUnpackTensorTuple(PyObject* obj);

// DType
bool PyDTypeCheck(PyObject* obj);
Symbol<DType> PyUnpackDType(PyObject* obj);

// DType list
bool PyDTypeSequenceCheck(PyObject* obj);
std::vector<Symbol<DType>> PyUnpackDTypeSequence(PyObject* obj);

// Shape list
bool PyShapeSequenceCheck(PyObject* obj);
std::vector<Shape> PyUnpackShapeSequence(PyObject* obj);

// Generator
bool PyGeneratorCheck(PyObject* obj);
std::shared_ptr<Generator> PyUnpackGenerator(PyObject* obj);

// Device
bool PyDeviceCheck(PyObject* obj);
Symbol<Device> PyUnpackDevice(PyObject* obj);

// Placement
bool PyParallelDescCheck(PyObject* obj);
Symbol<ParallelDesc> PyUnpackParallelDesc(PyObject* obj);

// SBP
bool PySbpParallelCheck(PyObject* obj);
Symbol<SbpParallel> PyUnpackSbpParallel(PyObject* obj);

// SBP list
bool PySbpParallelSequenceCheck(PyObject* obj);
std::vector<Symbol<SbpParallel>> PyUnpackSbpParallelSequence(PyObject* obj);

// Tensor index
bool PyTensorIndexCheck(PyObject* obj);
TensorIndex PyUnpackTensorIndex(PyObject* obj);

// OpExpr
bool PyOpExprCheck(PyObject* obj);
std::shared_ptr<OpExpr> PyUnpackOpExpr(PyObject* obj);

template<typename T>
inline PyObject* CastToPyObject(T&& t) {
  return py::cast(t).inc_ref().ptr();
}

template<>
inline PyObject* CastToPyObject<Maybe<Tensor>>(Maybe<Tensor>&& t) {
  return PyTensor_New(t.GetPtrOrThrow());
}

template<>
inline PyObject* CastToPyObject<Maybe<TensorTuple>>(Maybe<TensorTuple>&& t) {
  const auto& tensor_tuple = t.GetPtrOrThrow();
  py::tuple tup(tensor_tuple->size());
  for (int i = 0; i < tensor_tuple->size(); ++i) { tup[i] = py::cast(tensor_tuple->at(i)); }
  return py::cast<py::object>(tup).inc_ref().ptr();
}

template<>
inline PyObject* CastToPyObject<Maybe<void>>(Maybe<void>&& t) {
  t.GetOrThrow();
  Py_RETURN_NONE;
}

// int64_t
Maybe<int64_t> PyUnpackLong(PyObject* py_obj);

}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_COMMON_H_
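One subtlety behind PyScalarCheck/PyUnpackScalar declared here (and defined in common.cpp): CPython's bool is a subclass of int, so PyLong_Check also matches True and False, and the bool branch must be tested first. A quick Python mirror of that ordering:

def unpack_scalar(obj):
    if isinstance(obj, bool):  # must precede the int check: isinstance(True, int) is True
        return obj is True
    if isinstance(obj, int):
        return int(obj)
    if isinstance(obj, float):
        return float(obj)
    raise TypeError(f"The object is not scalar, but is {type(obj).__name__}")

assert unpack_scalar(True) is True
assert unpack_scalar(3) == 3
assert unpack_scalar(2.5) == 2.5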
oneflow/api/python/functional/dispatch_stateful_ops.cpp (new file, 0 → 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
namespace oneflow {
namespace one {
namespace functional {
namespace impl {

ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor("DispatchFeedInput",
                [](const std::shared_ptr<OpExpr>& op,
                   const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input});
                });
  m.add_functor("DispatchFetchOutput",
                [](const std::shared_ptr<OpExpr>& op,
                   const std::shared_ptr<Tensor>& input) -> Maybe<Tensor> {
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input});
                });
  m.add_functor("DispatchFeedVariable",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const Scalar& l2) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr<double>("l2", l2.As<double>()));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor(
      "DispatchOfrecordReader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& data_dir, int32_t data_part_num,
         const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size,
         int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,
         const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("data_dir", data_dir));
        JUST(attrs.SetAttr("data_part_num", data_part_num));
        JUST(attrs.SetAttr("part_name_prefix", part_name_prefix));
        JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size));
        JUST(attrs.SetAttr("random_shuffle", random_shuffle));
        JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr("seed", seed));
        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
      });
  m.add_functor(
      "DispatchOfrecordReader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& data_dir, int32_t data_part_num,
         const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size,
         int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed,
         const Symbol<ParallelDesc>& placement,
         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("data_dir", data_dir));
        JUST(attrs.SetAttr("data_part_num", data_part_num));
        JUST(attrs.SetAttr("part_name_prefix", part_name_prefix));
        JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size));
        JUST(attrs.SetAttr("random_shuffle", random_shuffle));
        JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr("seed", seed));
        JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
        return OpInterpUtil::Dispatch<Tensor>(*op, {},
                                              OpExprInterpContext(attrs, placement, nd_sbp));
      });
  m.add_functor("DispatchOfrecordRawDecoder",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const std::string& name, const Shape& shape, const Symbol<DType>& data_type,
                   bool dim1_varying_length, bool truncate) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("name", name));
                  JUST(attrs.SetAttr("shape", shape));
                  JUST(attrs.SetAttr("data_type", data_type->data_type()));
                  JUST(attrs.SetAttr("dim1_varying_length", dim1_varying_length));
                  JUST(attrs.SetAttr("truncate", truncate));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor("DispatchCoinFlip",
                [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability,
                   int64_t seed, bool has_seed,
                   const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("probability", probability.As<float>()));
                  JUST(attrs.SetAttr("batch_size", batch_size));
                  JUST(attrs.SetAttr("seed", seed));
                  JUST(attrs.SetAttr("has_seed", has_seed));
                  return OpInterpUtil::Dispatch<Tensor>(
                      *op, {}, OpExprInterpContext(attrs, JUST(device)));
                });
  m.add_functor("DispatchCoinFlip",
                [](const std::shared_ptr<OpExpr>& op, int64_t batch_size, Scalar probability,
                   int64_t seed, bool has_seed, const Symbol<ParallelDesc>& placement,
                   const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("probability", probability.As<float>()));
                  JUST(attrs.SetAttr("batch_size", batch_size));
                  JUST(attrs.SetAttr("seed", seed));
                  JUST(attrs.SetAttr("has_seed", has_seed));
                  JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
                  auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
                  return OpInterpUtil::Dispatch<Tensor>(
                      *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
                });
  m.add_functor("DispatchDistributedPariticalFCSample",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& weight,
                   const std::shared_ptr<Tensor>& label,
                   const int64_t& num_sample) -> Maybe<TensorTuple> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr<int64_t>("num_sample", num_sample));
                  return OpInterpUtil::Dispatch<TensorTuple>(*op, {weight, label}, attrs);
                });
  m.add_functor(
      "DispatchCropMirrorNormalizeFromUint8",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& input, int64_t crop_h,
         int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,
         const std::vector<float>& std, const Symbol<DType>& output_dtype,
         const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("color_space", color_space));
        JUST(attrs.SetAttr("output_layout", output_layout));
        JUST(attrs.SetAttr("mean", mean));
        JUST(attrs.SetAttr("std", std));
        JUST(attrs.SetAttr("crop_h", crop_h));
        JUST(attrs.SetAttr("crop_w", crop_w));
        JUST(attrs.SetAttr("crop_pos_x", crop_pos_x));
        JUST(attrs.SetAttr("crop_pos_y", crop_pos_y));
        JUST(attrs.SetAttr("output_dtype", output_dtype->data_type()));
        return OpInterpUtil::Dispatch<Tensor>(*op, input, attrs);
      });
  m.add_functor(
      "DispatchCropMirrorNormalizeFromTensorBuffer",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& input, int64_t crop_h,
         int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector<float>& mean,
         const std::vector<float>& std, const Symbol<DType>& output_dtype,
         const std::string& output_layout, const std::string& color_space) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("color_space", color_space));
        JUST(attrs.SetAttr("output_layout", output_layout));
        JUST(attrs.SetAttr("mean", mean));
        JUST(attrs.SetAttr("std", std));
        JUST(attrs.SetAttr("crop_h", crop_h));
        JUST(attrs.SetAttr("crop_w", crop_w));
        JUST(attrs.SetAttr("crop_pos_x", crop_pos_x));
        JUST(attrs.SetAttr("crop_pos_y", crop_pos_y));
        JUST(attrs.SetAttr("output_dtype", output_dtype->data_type()));
        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
      });
  m.add_functor(
      "DispatchOfrecordImageDecoderRandomCrop",
      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
         const std::string& name, const std::string& color_space,
         const std::vector<float>& random_area, const std::vector<float>& random_aspect_ratio,
         int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("name", name));
        JUST(attrs.SetAttr("color_space", color_space));
        JUST(attrs.SetAttr("num_attempts", num_attempts));
        JUST(attrs.SetAttr("seed", seed));
        JUST(attrs.SetAttr("has_seed", has_seed));
        JUST(attrs.SetAttr("random_area", random_area));
        JUST(attrs.SetAttr("random_aspect_ratio", random_aspect_ratio));
        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
      });
  m.add_functor("DispatchOfrecordImageDecoder",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const std::string& name, const std::string& color_space) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("name", name));
                  JUST(attrs.SetAttr("color_space", color_space));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor(
      "DispatchImageDecoderRandomCropResize",
      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
         int64_t target_width, int64_t target_height, int64_t seed, int64_t num_workers,
         int64_t max_num_pixels, float random_area_min, float random_area_max,
         float random_aspect_ratio_min, float random_aspect_ratio_max, int64_t warmup_size,
         int64_t num_attempts) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("target_width", target_width));
        JUST(attrs.SetAttr("target_height", target_height));
        JUST(attrs.SetAttr("seed", seed));
        JUST(attrs.SetAttr("num_workers", num_workers));
        JUST(attrs.SetAttr("max_num_pixels", max_num_pixels));
        JUST(attrs.SetAttr("random_area_min", random_area_min));
        JUST(attrs.SetAttr("random_area_max", random_area_max));
        JUST(attrs.SetAttr("random_aspect_ratio_min", random_aspect_ratio_min));
        JUST(attrs.SetAttr("random_aspect_ratio_max", random_aspect_ratio_max));
        JUST(attrs.SetAttr("warmup_size", warmup_size));
        JUST(attrs.SetAttr("num_attempts", num_attempts));
        return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
      });
  m.add_functor(
      "DispatchTensorBufferToListOfTensorsV2",
      [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
         const std::vector<Shape>& out_shapes, const std::vector<Symbol<DType>>& out_dtypes,
         bool dynamic_out) -> Maybe<TensorTuple> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("out_shapes", out_shapes));
        JUST(attrs.SetAttr("dynamic_out", dynamic_out));
        auto out_data_types = std::vector<DataType>();
        for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) {
          out_data_types.emplace_back((*it)->data_type());
        }
        JUST(attrs.SetAttr("out_dtypes", out_data_types));
        return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
      });
  m.add_functor("DispatchImageResizeKeepAspectRatio",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer,
                   const std::string& interpolation_type) -> Maybe<TensorTuple> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("target_size", target_size));
                  JUST(attrs.SetAttr("min_size", min_size));
                  JUST(attrs.SetAttr("max_size", max_size));
                  JUST(attrs.SetAttr("resize_longer", resize_longer));
                  JUST(attrs.SetAttr("interpolation_type", interpolation_type));
                  return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
                });
  m.add_functor("DispatchImageResizeToFixed",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   int64_t target_width, int64_t target_height, int64_t channels,
                   const Symbol<DType>& data_type,
                   const std::string& interpolation_type) -> Maybe<TensorTuple> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("target_width", target_width));
                  JUST(attrs.SetAttr("target_height", target_height));
                  JUST(attrs.SetAttr("channels", channels));
                  JUST(attrs.SetAttr("data_type", data_type->data_type()));
                  JUST(attrs.SetAttr("interpolation_type", interpolation_type));
                  return OpInterpUtil::Dispatch<TensorTuple>(*op, {input}, attrs);
                });
  m.add_functor("DispatchImageDecode",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const std::string& color_space,
                   const Symbol<DType>& data_type) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("color_space", color_space));
                  JUST(attrs.SetAttr("data_type", data_type->data_type()));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor("DispatchImageNormalize",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const std::vector<float>& mean,
                   const std::vector<float>& std) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("std", std));
                  JUST(attrs.SetAttr("mean", mean));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor(
      "DispatchCOCOReader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,
         const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,
         int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,
         bool stride_partition, int64_t session_id,
         const Optional<Symbol<Device>>& device) -> Maybe<TensorTuple> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("session_id", session_id));
        JUST(attrs.SetAttr("annotation_file", annotation_file));
        JUST(attrs.SetAttr("image_dir", image_dir));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr("random_seed", random_seed));
        JUST(attrs.SetAttr("group_by_ratio", group_by_ratio));
        JUST(attrs.SetAttr("remove_images_without_annotations",
                           remove_images_without_annotations));
        JUST(attrs.SetAttr("stride_partition", stride_partition));
        return OpInterpUtil::Dispatch<TensorTuple>(*op, {},
                                                   OpExprInterpContext(attrs, JUST(device)));
      });
  m.add_functor(
      "DispatchCOCOReader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& image_dir,
         const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch,
         int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations,
         bool stride_partition, int64_t session_id, const Symbol<ParallelDesc>& placement,
         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<TensorTuple> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("session_id", session_id));
        JUST(attrs.SetAttr("annotation_file", annotation_file));
        JUST(attrs.SetAttr("image_dir", image_dir));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr("random_seed", random_seed));
        JUST(attrs.SetAttr("group_by_ratio", group_by_ratio));
        JUST(attrs.SetAttr("remove_images_without_annotations",
                           remove_images_without_annotations));
        JUST(attrs.SetAttr("stride_partition", stride_partition));
        JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
        return OpInterpUtil::Dispatch<TensorTuple>(
            *op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
      });
  m.add_functor("DispatchImageBatchAlign",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   int32_t alignment, const Shape& shape, const Symbol<DType>& data_type,
                   bool dynamic_out) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("shape", shape));
                  JUST(attrs.SetAttr("data_type", data_type->data_type()));
                  JUST(attrs.SetAttr("alignment", alignment));
                  JUST(attrs.SetAttr("dynamic_out", dynamic_out));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor("DispatchOfrecordBytesDecoder",
                [](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
                   const std::string& name) -> Maybe<Tensor> {
                  MutableAttrMap attrs;
                  JUST(attrs.SetAttr("name", name));
                  return OpInterpUtil::Dispatch<Tensor>(*op, {input}, attrs);
                });
  m.add_functor(
      "DispatchOneRecReader",
      [](const std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,
         const int64_t batch_size, const bool random_shuffle, const std::string& shuffle_mode,
         const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed,
         const bool verify_example, const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr<std::vector<std::string>>("files", files));
        JUST(attrs.SetAttr<int64_t>("batch_size", batch_size));
        JUST(attrs.SetAttr<bool>("random_shuffle", random_shuffle));
        JUST(attrs.SetAttr<std::string>("shuffle_mode", shuffle_mode));
        JUST(attrs.SetAttr<int32_t>("shuffle_buffer_size", shuffle_buffer_size));
        JUST(attrs.SetAttr<bool>("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr<int64_t>("seed", random_seed));
        JUST(attrs.SetAttr<bool>("verify_example", verify_example));
        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
      });
  m.add_functor(
      "DispatchOneRecReader",
      [](const std::shared_ptr<OpExpr>& op, const std::vector<std::string>& files,
         const int64_t batch_size, const bool random_shuffle, const std::string& shuffle_mode,
         const int32_t shuffle_buffer_size, const bool shuffle_after_epoch, int64_t random_seed,
         const bool verify_example, const Symbol<ParallelDesc>& placement,
         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr<std::vector<std::string>>("files", files));
        JUST(attrs.SetAttr<int64_t>("batch_size", batch_size));
        JUST(attrs.SetAttr<bool>("random_shuffle", random_shuffle));
        JUST(attrs.SetAttr<std::string>("shuffle_mode", shuffle_mode));
        JUST(attrs.SetAttr<int32_t>("shuffle_buffer_size", shuffle_buffer_size));
        JUST(attrs.SetAttr<bool>("shuffle_after_epoch", shuffle_after_epoch));
        JUST(attrs.SetAttr<int64_t>("seed", random_seed));
        JUST(attrs.SetAttr<bool>("verify_example", verify_example));
        JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple))));
        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
        return OpInterpUtil::Dispatch<Tensor>(*op, {},
                                              OpExprInterpContext(attrs, placement, nd_sbp));
      });
  m.add_functor(
      "DispatchMegatronGptMmapDataLoader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& data_file_prefix,
         int64_t seq_length, int64_t label_length, int64_t num_samples, int64_t batch_size,
         const Symbol<DType>& dtype, const std::vector<int64_t>& split_sizes, int64_t split_index,
         bool shuffle, int64_t random_seed,
         const Optional<Symbol<Device>>& device) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("data_file_prefix", data_file_prefix));
        JUST(attrs.SetAttr("seq_length", seq_length));
        JUST(attrs.SetAttr("label_length", label_length));
        JUST(attrs.SetAttr("num_samples", num_samples));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("dtype", dtype->data_type()));
        JUST(attrs.SetAttr("split_sizes", split_sizes));
        JUST(attrs.SetAttr("split_index", split_index));
        JUST(attrs.SetAttr("shuffle", shuffle));
        JUST(attrs.SetAttr("random_seed", random_seed));
        return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
      });
  m.add_functor(
      "DispatchMegatronGptMmapDataLoader",
      [](const std::shared_ptr<OpExpr>& op, const std::string& data_file_prefix,
         int64_t seq_length, int64_t label_length, int64_t num_samples, int64_t batch_size,
         const Symbol<DType>& dtype, const std::vector<int64_t>& split_sizes, int64_t split_index,
         bool shuffle, int64_t random_seed, const Symbol<ParallelDesc>& placement,
         const std::vector<Symbol<SbpParallel>>& sbp_tuple) -> Maybe<Tensor> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("data_file_prefix", data_file_prefix));
        JUST(attrs.SetAttr("seq_length", seq_length));
        JUST(attrs.SetAttr("label_length", label_length));
        JUST(attrs.SetAttr("num_samples", num_samples));
        JUST(attrs.SetAttr("batch_size", batch_size));
        JUST(attrs.SetAttr("dtype", dtype->data_type()));
        JUST(attrs.SetAttr("split_sizes", split_sizes));
        JUST(attrs.SetAttr("split_index", split_index));
        JUST(attrs.SetAttr("shuffle", shuffle));
        JUST(attrs.SetAttr("random_seed", random_seed));
        auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
        return OpInterpUtil::Dispatch<Tensor>(*op, {},
                                              OpExprInterpContext(attrs, placement, nd_sbp));
      });
  m.add_functor(
      "DispatchRmspropUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         double scale, float l1, float l2, bool centered, float epsilon, float decay_rate,
         float weight_decay) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("centered", centered));
        JUST(attrs.SetAttr("epsilon", epsilon));
        JUST(attrs.SetAttr("decay_rate", decay_rate));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchAdamUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         float bias_correction1, float bias_correction2, double scale, float l1, float l2,
         float beta1, float beta2, float epsilon, float weight_decay, bool amsgrad,
         bool do_bias_correction) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("bias_correction1_val", bias_correction1));
        JUST(attrs.SetAttr("bias_correction2_val", bias_correction2));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("beta1", beta1));
        JUST(attrs.SetAttr("beta2", beta2));
        JUST(attrs.SetAttr("epsilon", epsilon));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(attrs.SetAttr("amsgrad", amsgrad));
        JUST(attrs.SetAttr("do_bias_correction", do_bias_correction));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchAdagradUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         double scale, float l1, float l2, float lr_decay, float weight_decay, float epsilon,
         int32_t train_step) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("lr_decay", lr_decay));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(attrs.SetAttr("epsilon", epsilon));
        JUST(attrs.SetAttr("train_step_val", train_step));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchMomentumUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         double scale, float l1, float l2, float beta, float weight_decay) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("beta", beta));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchSgdUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         double scale, float l1, float l2, float weight_decay) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchLambUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         float bias_correction1, float bias_correction2, double scale, float l1, float l2,
         float beta1, float beta2, float epsilon, float weight_decay,
         bool do_bias_correction) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("bias_correction1_val", bias_correction1));
        JUST(attrs.SetAttr("bias_correction2_val", bias_correction2));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1", l1));
        JUST(attrs.SetAttr("l2", l2));
        JUST(attrs.SetAttr("beta1", beta1));
        JUST(attrs.SetAttr("beta2", beta2));
        JUST(attrs.SetAttr("epsilon", epsilon));
        JUST(attrs.SetAttr("weight_decay", weight_decay));
        JUST(attrs.SetAttr("do_bias_correction", do_bias_correction));
        JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
        return Maybe<void>::Ok();
      });
  m.add_functor(
      "DispatchFtrlUpdate",
      [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
         double scale, float l1, float l2, float lr_power, float lambda1, float lambda2,
         float beta, float weight_decay) -> Maybe<void> {
        MutableAttrMap attrs;
        JUST(attrs.SetAttr("learning_rate_val", learning_rate));
        JUST(attrs.SetAttr("scale", scale));
        JUST(attrs.SetAttr("l1"
,
l1
));
JUST
(
attrs
.
SetAttr
(
"l2"
,
l2
));
JUST
(
attrs
.
SetAttr
(
"lr_power"
,
lr_power
));
JUST
(
attrs
.
SetAttr
(
"lambda1"
,
lambda1
));
JUST
(
attrs
.
SetAttr
(
"lambda2"
,
lambda2
));
JUST
(
attrs
.
SetAttr
(
"beta"
,
beta
));
JUST
(
attrs
.
SetAttr
(
"weight_decay"
,
weight_decay
));
JUST
(
OpInterpUtil
::
Dispatch
<
TensorTuple
>
(
*
op
,
inputs
,
attrs
));
return
Maybe
<
void
>::
Ok
();
});
m
.
add_functor
(
"DispatchEagerNcclAllReduce"
,
[](
const
std
::
shared_ptr
<
OpExpr
>&
op
,
const
std
::
shared_ptr
<
Tensor
>&
input
,
const
std
::
string
&
parallel_conf
,
bool
async_launch
)
->
Maybe
<
Tensor
>
{
MutableAttrMap
attrs
;
JUST
(
attrs
.
SetAttr
(
"parallel_conf"
,
parallel_conf
));
JUST
(
attrs
.
SetAttr
(
"async_launch"
,
async_launch
));
return
OpInterpUtil
::
Dispatch
<
Tensor
>
(
*
op
,
{
input
},
attrs
);
});
}
}
// namespace impl
}
// namespace functional
}
// namespace one
}
// namespace oneflow
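Every update functor registered above follows the same three-step pattern: copy the scalar arguments into a MutableAttrMap, dispatch the op over its input tensors, and return. A minimal sketch of that pattern for a hypothetical new entry (the functor name "DispatchClipGradNorm" and its "max_norm" attribute are illustrative only, not part of this commit):

  // Hypothetical registration following the pattern above; the name and
  // attribute are illustrative only.
  m.add_functor("DispatchClipGradNorm",
                [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
                   float max_norm) -> Maybe<void> {
                  MutableAttrMap attrs;
                  // Every attribute declared by the op must be set before dispatch.
                  JUST(attrs.SetAttr("max_norm", max_norm));
                  // Update ops mutate their inputs in place, so only a status is returned.
                  JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
                  return Maybe<void>::Ok();
                });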
oneflow/api/python/functional/dispatch_stateful_ops.yaml
0 → 100644
# Copyright 2020 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following data types are allowed,
# {
# "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool",
# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList",
# "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement",
# "Sbp", "SbpList"
# }
- name: "dispatch_feed_input"
  signature: "Tensor (OpExpr op, Tensor input) => DispatchFeedInput"
  bind_python: True

- name: "dispatch_feed_variable"
  signature: "Tensor (OpExpr op, Tensor input, Scalar l2) => DispatchFeedVariable"
  bind_python: True

- name: "dispatch_fetch_output"
  signature: "Tensor (OpExpr op, Tensor input) => DispatchFetchOutput"
  bind_python: True

- name: "dispatch_ofrecord_reader"
  signature: [
    "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Device device=None) => DispatchOfrecordReader",
    "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Placement placement, SbpList sbp) => DispatchOfrecordReader",
  ]
  bind_python: True

- name: "dispatch_ofrecord_raw_decoder"
  signature: "Tensor (OpExpr op, Tensor input, String name, Shape shape, DataType data_type, Bool dim1_varying_length=False, Bool truncate=False) => DispatchOfrecordRawDecoder"
  bind_python: True

- name: "dispatch_coin_flip"
  signature: [
    "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Device device=None) => DispatchCoinFlip",
    "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Placement placement, SbpList sbp) => DispatchCoinFlip",
  ]
  bind_python: True

- name: "dispatch_distributed_partial_fc_sample"
  signature: "TensorTuple (OpExpr op, Tensor weight, Tensor label, Int64 num_sample) => DispatchDistributedPariticalFCSample"
  bind_python: True

- name: "dispatch_crop_mirror_normalize_from_uint8"
  signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromUint8"
  bind_python: True

- name: "dispatch_crop_mirror_normalize_from_tensorbuffer"
  signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromTensorBuffer"
  bind_python: True

- name: "dispatch_ofrecord_image_decoder_random_crop"
  signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\", FloatList random_area, FloatList random_aspect_ratio, Int32 num_attempts=10, Int64 seed=-1, Bool has_seed=False) => DispatchOfrecordImageDecoderRandomCrop"
  bind_python: True

- name: "dispatch_ofrecord_image_decoder"
  signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\") => DispatchOfrecordImageDecoder"
  bind_python: True

- name: "dispatch_image_decoder_random_crop_resize"
  signature: "Tensor (OpExpr op, Tensor input, Int64 target_width, Int64 target_height, Int64 seed, Int64 num_workers=3, Int64 max_num_pixels=67108864, Float random_area_min=0.08f, Float random_area_max=1.0f, Float random_aspect_ratio_min=0.75f, Float random_aspect_ratio_max=1.333333f, Int64 warmup_size=6400, Int64 num_attempts=10) => DispatchImageDecoderRandomCropResize"
  bind_python: True

- name: "dispatch_tensor_buffer_to_list_of_tensors_v2"
  signature: "TensorTuple (OpExpr op, Tensor input, ShapeList out_shapes, DataTypeList out_dtypes, Bool dynamic_out) => DispatchTensorBufferToListOfTensorsV2"
  bind_python: True

- name: "dispatch_image_resize_keep_aspect_ratio"
  signature: "TensorTuple (OpExpr op, Tensor input, Int32 target_size, Int32 min_size=0, Int32 max_size=0, Bool resize_longer=False, String interpolation_type=\"bilinear\") => DispatchImageResizeKeepAspectRatio"
  bind_python: True

- name: "dispatch_image_resize_to_fixed"
  signature: "TensorTuple (OpExpr op, Tensor input, Int64 target_width=0, Int64 target_height=0, Int64 channels=3, DataType data_type=kUInt8, String interpolation_type=\"bilinear\") => DispatchImageResizeToFixed"
  bind_python: True

- name: "dispatch_image_decode"
  signature: "Tensor (OpExpr op, Tensor input, String color_space=\"BGR\", DataType data_type=kUInt8) => DispatchImageDecode"
  bind_python: True

- name: "dispatch_image_normalize"
  signature: "Tensor (OpExpr op, Tensor input, FloatList mean, FloatList std) => DispatchImageNormalize"
  bind_python: True

- name: "dispatch_coco_reader"
  signature: [
    "TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Device device=None) => DispatchCOCOReader",
    "TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Placement placement, SbpList sbp) => DispatchCOCOReader",
  ]
  bind_python: True

- name: "dispatch_image_batch_align"
  signature: "Tensor (OpExpr op, Tensor input, Int32 alignment, Shape shape, DataType data_type, Bool dynamic_out) => DispatchImageBatchAlign"
  bind_python: True

- name: "dispatch_ofrecord_bytes_decoder"
  signature: "Tensor (OpExpr op, Tensor input, String name) => DispatchOfrecordBytesDecoder"
  bind_python: True

- name: "dispatch_onerec_reader"
  signature: [
    "Tensor (OpExpr op, StringList files, Int64 batch_size, Bool random_shuffle, String shuffle_mode, Int32 shuffle_buffer_size=1024, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool verify_example=True, Device device=None) => DispatchOneRecReader",
    "Tensor (OpExpr op, StringList files, Int64 batch_size, Bool random_shuffle, String shuffle_mode, Int32 shuffle_buffer_size=1024, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool verify_example=True, Placement placement, SbpList sbp) => DispatchOneRecReader",
  ]
  bind_python: True

- name: "dispatch_megatron_gpt_mmap_data_loader"
  signature: [
    "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Device device=None) => DispatchMegatronGptMmapDataLoader",
    "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Placement placement, SbpList sbp) => DispatchMegatronGptMmapDataLoader",
  ]
  bind_python: True

- name: "dispatch_rmsprop_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Bool centered=False, Float epsilon=1e-8, Float decay_rate=0.99, Float weight_decay=0.0) => DispatchRmspropUpdate"
  bind_python: True

- name: "dispatch_adam_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool amsgrad=False, Bool do_bias_correction=True) => DispatchAdamUpdate"
  bind_python: True

- name: "dispatch_adagrad_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_decay=0, Float weight_decay=0, Float epsilon=1e-10, Int32 train_step_val=0) => DispatchAdagradUpdate"
  bind_python: True

- name: "dispatch_momentum_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate"
  bind_python: True

- name: "dispatch_sgd_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate"
  bind_python: True

- name: "dispatch_lamb_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate"
  bind_python: True

- name: "dispatch_ftrl_update"
  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate"
  bind_python: True

- name: "dispatch_eager_nccl_all_reduce"
  signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce"
  bind_python: True
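Each entry above binds the Python-visible "name" to the C++ functor after "=>", with the signature string driving argument parsing and default values. Entries listing two signatures overload on "Device device=None" versus "Placement placement, SbpList sbp"; as the functors in dispatch_stateful_ops.cpp above show, this selects between a device-scoped and a placement-scoped interpreter context. A condensed sketch of the two dispatch paths (not a literal function from this commit):

// Condensed sketch of the two overload paths; `attrs` is assumed to be filled
// from the signature's attributes as in the functors above.
Maybe<Tensor> DispatchReaderSketch(const std::shared_ptr<OpExpr>& op, MutableAttrMap& attrs,
                                   const Optional<Symbol<Device>>& device,
                                   const Symbol<ParallelDesc>& placement,
                                   const std::vector<Symbol<SbpParallel>>& sbp_tuple,
                                   bool global) {
  if (!global) {
    // "Device device=None" overload: run locally on the given (or default) device.
    return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, JUST(device)));
  }
  // "Placement placement, SbpList sbp" overload: run globally under placement + nd_sbp.
  auto nd_sbp = JUST(GetNdSbp(sbp_tuple));
  return OpInterpUtil::Dispatch<Tensor>(*op, {}, OpExprInterpContext(attrs, placement, nd_sbp));
}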
oneflow/api/python/functional/function_def.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_
#include <memory>
#include <string>
#include <vector>
#include "oneflow/api/python/functional/python_arg.h"
#include "oneflow/api/python/functional/value_types.h"
namespace oneflow {
namespace one {
namespace functional {

struct ReturnDef {
  explicit ReturnDef(const ValueType& t) : type(t) {}
  ValueType type;
};

struct ArgumentDef {
  ArgumentDef(const std::string& arg_name, const ValueType& arg_type, int arg_size,
              bool arg_keyword_only, bool arg_optional)
      : name(arg_name),
        type(arg_type),
        size(arg_size),
        keyword_only(arg_keyword_only),
        optional(arg_optional),
        has_default_value(false) {}

  template<typename T>
  ArgumentDef(const std::string& arg_name, const T& arg_val, int arg_size, bool arg_keyword_only,
              bool arg_optional)
      : name(arg_name),
        type(ValueTypeOf<T>()),
        size(arg_size),
        keyword_only(arg_keyword_only),
        optional(arg_optional),
        has_default_value(true) {
    default_value = std::make_shared<detail::TypedDefaultVal<T>>(arg_val);
  }

  std::string name;
  ValueType type;
  int size;
  bool keyword_only;
  bool optional;
  bool has_default_value;
  std::shared_ptr<const detail::DefaultVal> default_value;
};

struct FunctionDef {
  std::string name;
  ReturnDef return_def;
  std::vector<ArgumentDef> argument_def;
};

}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_FUNCTION_DEF_H_
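For illustration, the FunctionDef that the binding generator would produce for a signature such as "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False)" might look like the following hand-written sketch (the ValueType constants are the ones used in python_arg.cpp below; the generated code is not part of this diff):

// Hand-written sketch of a generated FunctionDef; not emitted by this commit.
FunctionDef all_reduce_def{
    /*name=*/"dispatch_eager_nccl_all_reduce",
    /*return_def=*/ReturnDef(kTENSOR),
    /*argument_def=*/
    {
        // Type-only constructor: required argument, has_default_value == false.
        ArgumentDef("op", kOPEXPR_REF, /*size=*/0, /*keyword_only=*/false, /*optional=*/false),
        ArgumentDef("input", kTENSOR, 0, false, false),
        ArgumentDef("parallel_conf", kSTRING, 0, false, false),
        // Value constructor: stores a TypedDefaultVal<bool> and sets has_default_value.
        ArgumentDef("async_launch", false, 0, false, false),
    }};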
oneflow/api/python/functional/indexing.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/indexing.h"
#include <object.h>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/api/python/functional/tensor_api.yaml.h"
#include "oneflow/core/common/foreign_lock_helper.h"
namespace oneflow {
namespace one {
namespace functional {
namespace detail {

void PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step) {
  PySliceObject* obj = (PySliceObject*)object;
  if (obj->step == Py_None) {
    *step = 1;
  } else {
    CHECK_OR_THROW(_PyEval_SliceIndex(obj->step, step))
        << "Invalid slice " << PyObjectToReprStr(object);
    CHECK_NE_OR_THROW(*step, 0) << "slice step cannot be zero.";
    if (*step < -PY_SSIZE_T_MAX) *step = -PY_SSIZE_T_MAX;
  }
  if (obj->start == Py_None) {
    *start = *step < 0 ? PY_SSIZE_T_MAX : 0;
  } else {
    CHECK_OR_THROW(_PyEval_SliceIndex(obj->start, start))
        << "Invalid slice " << PyObjectToReprStr(object);
  }
  if (obj->stop == Py_None) {
    *stop = *step < 0 ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX;
  } else {
    CHECK_OR_THROW(_PyEval_SliceIndex(obj->stop, stop))
        << "Invalid slice " << PyObjectToReprStr(object);
  }
}

DataType InferScalarType(PyObject* object) {
  if (PyBool_Check(object)) {
    return DataType::kBool;
  } else if (PyLong_Check(object)) {
    return DataType::kInt64;
  } else if (PyArray_Check(object)) {
    return numpy::GetOFDataTypeFromNpArray(reinterpret_cast<PyArrayObject*>(object)).GetOrThrow();
  } else if (PyArray_CheckScalar(object)) {
    return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num).GetOrThrow();
  } else if (PySequence_Check(object)) {
    int64_t length = PySequence_Length(object);
    CHECK_GT_OR_THROW(length, 0) << "Index should not be empty.";
    DataType scalar_type = DataType::kInvalidDataType;
    for (int64_t i = 0; i < length; ++i) {
      PyObjectPtr item(PySequence_GetItem(object, i));
      const auto& item_scalar_type = InferScalarType(item.get());
      if (scalar_type != DataType::kInvalidDataType) {
        CHECK_EQ_OR_THROW(scalar_type, item_scalar_type)
            << "Different scalar types are not allowed.";
      } else {
        scalar_type = item_scalar_type;
      }
    }
    return scalar_type;
  }
  THROW(TypeError) << "Can't infer scalar type of " << Py_TYPE(object)->tp_name;
  return DataType::kInvalidDataType;
}

void ParseScalar(PyObject* object, char* data, const DataType& dtype) {
  if (dtype == DataType::kInt64) {
    CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object))
        << "Expected a long value.";
    *(reinterpret_cast<int64_t*>(data)) = PyLong_AsLongLong(object);
  } else if (dtype == DataType::kInt32) {
    CHECK_OR_THROW(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object))
        << "Expected a long value.";
    *(reinterpret_cast<int32_t*>(data)) = PyLong_AsLongLong(object);
  } else if (dtype == DataType::kUInt8 || dtype == DataType::kBool) {
    CHECK_OR_THROW(PyBool_Check(object) || PyLong_Check(object)
                   || numpy::PyArrayCheckLongScalar(object))
        << "Expected a boolean or long value.";
    if (PyBool_Check(object) || numpy::PyArrayCheckBoolScalar(object)) {
      *(reinterpret_cast<bool*>(data)) = (object == Py_True);
    } else {
      int64_t value = PyLong_AsLongLong(object);
      CHECK_OR_THROW(value >= 0 && value <= 255) << "Out of range 0-255.";
      *(reinterpret_cast<uint8_t*>(data)) = static_cast<uint8_t>(value);
    }
  } else {
    THROW(TypeError) << "Can't parse scalar with data type " << dtype;
  }
}

void RecursiveParseAndAssign(PyObject* object, char* data, const int& ndims, const int& dim,
                             const ShapeView& shape, const DimVector& strides,
                             const DataType& dtype) {
  if (dim == ndims) { return ParseScalar(object, data, dtype); }
  auto seq = PyObjectPtr(PySequence_Fast(object, "Expected a sequence."));
  int64_t size = PySequence_Fast_GET_SIZE(seq.get());
  CHECK_EQ_OR_THROW(size, shape.At(dim)) << "Sequence size is " << size << " at dimension " << dim
                                         << ", but expected " << shape.At(dim);
  for (int64_t i = 0; i < size; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    RecursiveParseAndAssign(item, data, ndims, dim + 1, shape, strides, dtype);
    data += strides.at(dim) * GetSizeOfDataType(dtype);
  }
}

void ParseArrayToBlob(PyObject* object, Blob* blob) {
  const DataType dtype = blob->data_type();
  const int ndims = blob->shape().NumAxes();
  DimVector strides(ndims);
  int64_t size = 1;
  for (int i = ndims - 1; i >= 0; --i) {
    strides[i] = size;
    size *= blob->shape().At(i);
  }
  RecursiveParseAndAssign(object, blob->mut_dptr<char>(), ndims, 0, blob->shape(), strides, dtype);
}

Shape InferArraySizes(PyObject* object) {
  DimVector sizes;
  PyObject* seq = object;
  PyObjectPtr handle;
  while (PySequence_Check(seq)) {
    int64_t length = PySequence_Length(seq);
    CHECK_GT_OR_THROW(length, 0) << "Index should not be empty.";
    sizes.emplace_back(length);
    CHECK_LE_OR_THROW(sizes.size(), /*MAX_DIMS=*/128)
        << "Too many dimensions " << Py_TYPE(seq)->tp_name;
    if (length == 0) break;
    handle = PyObjectPtr(PySequence_GetItem(seq, 0));
    seq = handle.get();
  }
  return Shape(sizes);
}

Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
  const DataType dtype = InferScalarType(object);
  const auto& device = JUST(Device::New("cpu"));
  // index type must be integers
  if (!(IsIntegralDataType(dtype) || (IsBoolDataType(dtype)))) {
    return Error::IndexError() << "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis "
                                  "(`None`) and integer or boolean arrays are valid indices";
  }
  // In advanced indexing condition, index can be array object, need to handle it specially.
  if (PyArray_Check(object)) {
    return TensorWithData(object, NullOpt, device, false, /*pin_memory=*/false);
  }
  const auto& sizes = InferArraySizes(object);
  const auto& tensor = JUST(functional::Empty(sizes, CHECK_JUST(DType::Get(dtype)), device,
                                              /*pin_memory=*/false));
  // Prevent the python object release until the callback is complete.
  Py_INCREF(object);
  auto handle = std::shared_ptr<PyObject>(PyObjectPtr(object));
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->AccessBlobByCallback(
        JUST(tensor->AsMirroredTensor()),
        [handle](uint64_t ofblob_ptr) {
          auto* of_blob = reinterpret_cast<OfBlob*>(ofblob_ptr);
          CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire(
              [&]() -> Maybe<void> {
                ParseArrayToBlob(handle.get(), of_blob->mut_blob());
                return Maybe<void>::Ok();
              }));
        },
        "mut");
  }));
  return tensor;
}

IndexItem UnpackIndexItem(PyObject* object) {
  if (object == Py_Ellipsis) {
    return IndexItem(EllipsisIndex{});
  } else if (PySlice_Check(object)) {
    Py_ssize_t start, end, step;
    PySliceUnpack(object, &start, &end, &step);
    return IndexItem(start, end, step);
  } else if (PyLong_Check(object) && object != Py_False && object != Py_True) {
    return IndexItem(static_cast<int64_t>(PyLong_AsLongLong(object)));
  } else if (numpy::PyArrayCheckLongScalar(object)) {
    return IndexItem(static_cast<int64_t>(PyLong_AsLongLong(object)));
  } else if (object == Py_False || object == Py_True) {
    return IndexItem(object == Py_True);
  } else if (object == Py_None) {
    return IndexItem(NoneIndex{});
  } else if (PyTensor_Check(object)) {
    return IndexItem(PyTensor_Unpack(object));
  } else if (PySequence_Check(object)) {
    return IndexItem(ConvertToIndexingTensor(object).GetPtrOrThrow());
  }
  THROW(TypeError) << "Invalid index " << Py_TYPE(object)->tp_name;
  return IndexItem();
}

}  // namespace detail
}  // namespace functional
}  // namespace one
}  // namespace oneflow
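As a concrete illustration of the slice path above: indexing `t[1:5:2]` hands UnpackIndexItem a PySliceObject, which PySliceUnpack converts into the (start, stop, step) triple of an IndexItem. A minimal sketch, assuming an initialized interpreter with the GIL held:

// Minimal sketch of the slice path; assumes Py_Initialize() has run and the GIL is held.
PyObject* lo = PyLong_FromLong(1);
PyObject* hi = PyLong_FromLong(5);
PyObject* st = PyLong_FromLong(2);
PyObject* slice = PySlice_New(lo, hi, st);  // the object behind `1:5:2`
Py_ssize_t start = 0, stop = 0, step = 0;
oneflow::one::functional::detail::PySliceUnpack(slice, &start, &stop, &step);
// start == 1, stop == 5, step == 2; a missing bound would default to the
// PY_SSIZE_T limit chosen by the sign of step, as in the branches above.
Py_DECREF(slice);
Py_DECREF(lo);
Py_DECREF(hi);
Py_DECREF(st);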
oneflow/api/python/functional/indexing.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_
#include <Python.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/functional/tensor_index.h"
namespace oneflow {
namespace one {
namespace functional {
namespace detail {

void PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step);

Maybe<Tensor> ConvertToIndexingTensor(PyObject* object);

IndexItem UnpackIndexItem(PyObject* object);

}  // namespace detail
}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_INDEXING_H_
oneflow/api/python/functional/python_arg.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/python_arg.h"
#include "oneflow/api/python/framework/tensor.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/indexing.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/tensor_index.h"
namespace py = pybind11;

namespace oneflow {
namespace one {
namespace functional {

#define INSTANCE_OBJECT_AS_INTEGER(T)                                                            \
  template<>                                                                                     \
  T PythonArg::ObjectAs<T>() const {                                                             \
    return static_cast<T>(PyLong_AsLongLong(object_));                                           \
  }                                                                                              \
  template<>                                                                                     \
  std::vector<T> PythonArg::ObjectAs<std::vector<T>>() const {                                   \
    if (size_ > 0 && PyLong_Check(object_)) {                                                    \
      return std::vector<T>(size_, static_cast<T>(PyLong_AsLongLong(object_)));                  \
    }                                                                                            \
    return PyUnpackLongSequence<T>(object_);                                                     \
  }                                                                                              \
  template<>                                                                                     \
  std::shared_ptr<std::vector<T>> PythonArg::ObjectAs<std::shared_ptr<std::vector<T>>>() const { \
    return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>());                         \
  }

OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_INTEGER, INTEGER_TYPE_SEQ)
#undef INSTANCE_OBJECT_AS_INTEGER

#define INSTANCE_OBJECT_AS_FLOAT(T)                                                              \
  template<>                                                                                     \
  T PythonArg::ObjectAs<T>() const {                                                             \
    return static_cast<T>(PyFloat_AsDouble(object_));                                            \
  }                                                                                              \
  template<>                                                                                     \
  std::vector<T> PythonArg::ObjectAs<std::vector<T>>() const {                                   \
    if (size_ > 0 && PyFloat_Check(object_)) {                                                   \
      return std::vector<T>(size_, static_cast<T>(PyFloat_AsDouble(object_)));                   \
    }                                                                                            \
    return PyUnpackFloatSequence<T>(object_);                                                    \
  }                                                                                              \
  template<>                                                                                     \
  std::shared_ptr<std::vector<T>> PythonArg::ObjectAs<std::shared_ptr<std::vector<T>>>() const { \
    return std::make_shared<std::vector<T>>(ObjectAs<std::vector<T>>());                         \
  }

OF_PP_FOR_EACH_TUPLE(INSTANCE_OBJECT_AS_FLOAT, FLOATING_TYPE_SEQ)
#undef INSTANCE_OBJECT_AS_FLOAT

#define INSTANCE_OBJECT_AS_SHARED_PTR(T)                               \
  template<>                                                           \
  std::shared_ptr<T> PythonArg::ObjectAs<std::shared_ptr<T>>() const { \
    return std::make_shared<T>(ObjectAs<T>());                         \
  }

template<>
std::string PythonArg::ObjectAs<std::string>() const {
  return PyStringAsString(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(std::string)

template<>
Scalar PythonArg::ObjectAs<Scalar>() const {
  return PyUnpackScalar(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(Scalar)

template<>
std::shared_ptr<one::Tensor> PythonArg::ObjectAs<std::shared_ptr<one::Tensor>>() const {
  return PyTensor_Unpack(object_);
}

template<>
one::TensorTuple PythonArg::ObjectAs<one::TensorTuple>() const {
  if (PyTensorTupleCheck(object_)) { return *PyUnpackTensorTuple(object_); }
  const auto& v = PyUnpackTensorSequence(object_);
  one::TensorTuple values(v.size());
  for (int i = 0; i < v.size(); ++i) { values[i] = v.at(i); }
  return values;
}
INSTANCE_OBJECT_AS_SHARED_PTR(one::TensorTuple)

template<>
Symbol<DType> PythonArg::ObjectAs<Symbol<DType>>() const {
  return PyUnpackDType(object_);
}

template<>
std::vector<Symbol<DType>> PythonArg::ObjectAs<std::vector<Symbol<DType>>>() const {
  return PyUnpackDTypeSequence(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<DType>>)

template<>
Shape PythonArg::ObjectAs<Shape>() const {
  const auto& shape = PyUnpackLongSequence<int64_t>(object_);
  return Shape(DimVector(shape.begin(), shape.end()));
}
INSTANCE_OBJECT_AS_SHARED_PTR(Shape)

template<>
std::vector<Shape> PythonArg::ObjectAs<std::vector<Shape>>() const {
  return PyUnpackShapeSequence(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Shape>)

template<>
std::shared_ptr<one::Generator> PythonArg::ObjectAs<std::shared_ptr<one::Generator>>() const {
  return PyUnpackGenerator(object_);
}

template<>
Symbol<Device> PythonArg::ObjectAs<Symbol<Device>>() const {
  if (PyStringCheck(object_)) {
    std::string device_str = PyStringAsString(object_);
    return Device::ParseAndNew(device_str).GetOrThrow();
  }
  return PyUnpackDevice(object_);
}

template<>
Symbol<ParallelDesc> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {
  return PyUnpackParallelDesc(object_);
}

template<>
Symbol<SbpParallel> PythonArg::ObjectAs<Symbol<SbpParallel>>() const {
  return PyUnpackSbpParallel(object_);
}

template<>
std::vector<Symbol<SbpParallel>> PythonArg::ObjectAs<std::vector<Symbol<SbpParallel>>>() const {
  if (PySbpParallelCheck(object_)) {
    return std::vector<Symbol<SbpParallel>>(1, PyUnpackSbpParallel(object_));
  }
  return PyUnpackSbpParallelSequence(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<Symbol<SbpParallel>>)

template<>
TensorIndex PythonArg::ObjectAs<TensorIndex>() const {
  return PyUnpackTensorIndex(object_);
}
INSTANCE_OBJECT_AS_SHARED_PTR(TensorIndex)

template<>
std::shared_ptr<one::OpExpr> PythonArg::ObjectAs<std::shared_ptr<one::OpExpr>>() const {
  return PyUnpackOpExpr(object_);
}

template<>
PyObject* PythonArg::ObjectAs<PyObject*>() const {
  return object_;
}

template<>
std::vector<std::string> PythonArg::ObjectAs<std::vector<std::string>>() const {
  return PyUnpackSequence<std::string>(
      object_, [](PyObject* item) -> std::string { return PyStringAsString(item); });
}
INSTANCE_OBJECT_AS_SHARED_PTR(std::vector<std::string>)
#undef INSTANCE_OBJECT_AS_SHARED_PTR

bool PythonArg::TypeCheck(ValueType type) const {
  if (tag_ == HAS_DEFAULT) { return default_val_->value_type() == type; }
  switch (type) {
    case kINT32:
    case kUINT32:
    case kINT64:
    case kUINT64:
    case kBOOL: return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_);
    case kINT32_LIST:
    case kUINT32_LIST:
    case kINT64_LIST:
    case kUINT64_LIST:
    case kBOOL_LIST: return PyLongSequenceCheck(object_) || (size_ > 0 && PyLong_Check(object_));
    case kFLOAT:
    case kDOUBLE:
      return PyFloat_Check(object_) || PyLong_Check(object_)
             || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_);
    case kFLOAT_LIST:
    case kDOUBLE_LIST:
      return PyFloatSquenceCheck(object_)
             || (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_)));
    case kSTRING: return PyStringCheck(object_);
    case kSTRING_LIST: return PyStringSequenceCheck(object_);
    case kSCALAR:
      return PyScalarCheck(object_) || numpy::PyArrayCheckLongScalar(object_)
             || numpy::PyArrayCheckFloatScalar(object_);
    case kTENSOR:
    case kTENSOR_REF: return PyTensor_Check(object_);
    case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_);
    case kDTYPE: return PyDTypeCheck(object_);
    case kSHAPE: return PyLongSequenceCheck(object_);
    case kGENERATOR:
    case kGENERATOR_REF: return PyGeneratorCheck(object_);
    case kTENSOR_INDEX: return PyTensorIndexCheck(object_);
    case kDEVICE: return PyDeviceCheck(object_) || PyStringCheck(object_);
    case kPARALLEL_DESC: return PyParallelDescCheck(object_);
    case kSBP_PARALLEL: return PySbpParallelCheck(object_);
    case kSBP_PARALLEL_LIST:
      return PySbpParallelSequenceCheck(object_) || PySbpParallelCheck(object_);
    case kOPEXPR_REF: return PyOpExprCheck(object_);
    case kPY_OBJECT: return nullptr != object_;
    case kDTYPE_LIST: return PyDTypeSequenceCheck(object_);
    case kSHAPE_LIST: return PyShapeSequenceCheck(object_);
    default: {
      THROW(RuntimeError) << "Can not check type " << ValueTypeName(type);
    }
  }
  return false;
}

bool PythonArgCheck(const PythonArg& arg, ValueType type) { return arg.TypeCheck(type); }

}  // namespace functional
}  // namespace one
}  // namespace oneflow
oneflow/api/python/functional/python_arg.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_
#include <pybind11/pybind11.h>
#include <Python.h>
#include "oneflow/core/common/throw.h"
#include "oneflow/api/python/functional/value_types.h"
#include "oneflow/core/common/maybe.h"
namespace py = pybind11;

namespace oneflow {
namespace one {
namespace functional {

namespace detail {

struct DefaultVal {
  virtual ValueType value_type() const = 0;
  virtual const void* Ptr() const = 0;
};

template<typename T>
struct TypedDefaultVal final : public DefaultVal {
  T content;
  explicit TypedDefaultVal(const T& v) : content(v) {}

  ValueType value_type() const override { return ValueTypeOf<T>(); }
  const void* Ptr() const override { return &content; }
};

template<typename T>
struct optional_traits {
  using type = void;
};

template<typename T>
struct optional_traits<Optional<T>> {
  using type =
      decltype(std::declval<Optional<T>>().Data_YouAreNotAllowedToCallThisFuncOutsideThisFile());
};

}  // namespace detail

class PythonArg {
 public:
  PythonArg() = default;
  PythonArg(const py::object& object, int size = 0) : PythonArg(object.ptr(), size) {}
  PythonArg(PyObject* object, int size = 0)
      : object_(object), default_val_(), size_(size), tag_(HAS_OBJECT) {}
  PythonArg(const std::shared_ptr<const detail::DefaultVal>& value, int size = 0)
      : object_(nullptr), default_val_(value), size_(size), tag_(HAS_DEFAULT) {}

  template<typename T, typename std::enable_if<!py::detail::is_pyobject<T>::value, int>::type = 0>
  PythonArg(const T& value, int size = 0)
      : object_(nullptr),
        default_val_(std::make_shared<detail::TypedDefaultVal<T>>(value)),
        size_(size),
        tag_(HAS_DEFAULT) {}

  virtual ~PythonArg() = default;

  template<typename T, typename std::enable_if<!internal::IsOptional<T>::value, int>::type = 0>
  T As() const {
    if (tag_ == HAS_DEFAULT) {
      CHECK_EQ_OR_THROW(ValueTypeOf<T>(), default_val_->value_type())
          << "Could not convert default value from type " << default_val_->value_type()
          << " to type " << ValueTypeOf<T>();
      return *reinterpret_cast<const T*>(default_val_->Ptr());
    }
    CHECK_EQ_OR_THROW(tag_, HAS_OBJECT);
    return ObjectAs<oneflow::detail::remove_cvref_t<T>>();
  }

  template<typename T, typename std::enable_if<internal::IsOptional<T>::value, int>::type = 0>
  T As() const {
    if (tag_ == HAS_DEFAULT) {
      CHECK_EQ_OR_THROW(ValueTypeOf<T>(), default_val_->value_type())
          << "Could not convert default value from type " << default_val_->value_type()
          << " to type " << ValueTypeOf<T>();
      return *reinterpret_cast<const T*>(default_val_->Ptr());
    }
    CHECK_EQ_OR_THROW(tag_, HAS_OBJECT);
    if (object_ == Py_None) { return T(); }
    return ObjectAs<typename detail::optional_traits<T>::type>();
  }

  bool TypeCheck(ValueType type) const;

 private:
  template<typename T>
  T ObjectAs() const;

  PyObject* object_;
  std::shared_ptr<const detail::DefaultVal> default_val_;
  size_t size_;
  enum { HAS_OBJECT, HAS_DEFAULT, HAS_NONE } tag_;
};

bool PythonArgCheck(const PythonArg& arg, ValueType type);

}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_H_
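Taken together with the ObjectAs specializations in python_arg.cpp above, the class works like this sketch (assuming the GIL is held; the values are arbitrary examples):

// Sketch of consuming parsed arguments; values are arbitrary examples.
using oneflow::one::functional::PythonArg;

PyObject* obj = PyFloat_FromDouble(0.5);
PythonArg from_object(obj);             // wraps a live Python object (tag_ == HAS_OBJECT)
double x = from_object.As<double>();    // 0.5, via ObjectAs<double>

PythonArg from_default(0.9f);           // wraps a schema default (tag_ == HAS_DEFAULT)
float beta = from_default.As<float>();  // returns the stored TypedDefaultVal<float> content
Py_DECREF(obj);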
oneflow/api/python/functional/python_arg_parser.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/python_arg_parser.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/python_arg.h"
namespace oneflow {
namespace one {
namespace functional {

void FunctionSchema::ReportKwargsError(PyObject* kwargs, size_t nargs) const {
  PyObject *key = nullptr, *value = nullptr;
  Py_ssize_t pos = 0;
  while (PyDict_Next(kwargs, &pos, &key, &value)) {
    if (!PyStringCheck(key)) {
      THROW(TypeError) << def_->name << "(): keywords must be strings";
    }
    int64_t index = -1;
    const std::string string_key = PyStringAsString(key);
    for (int i = 0; i < def_->argument_def.size(); ++i) {
      const auto& arg = def_->argument_def.at(i);
      if (arg.name == string_key) {
        index = i;
        break;
      }
    }
    if (index < 0) {
      THROW(TypeError) << def_->name << "(): got an unexpected keyword argument '" << string_key
                       << "'";
    }
    if (index < nargs) {
      THROW(TypeError) << def_->name << "(): got multiple values for argument '" << string_key
                       << "'";
    }
  }
  THROW(TypeError) << def_->name << "(): kwargs unknown error";
}

// The argument parsing refers to the implementation of Pytorch.
bool FunctionSchema::Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args,
                           bool raise_exception) const {
  bool treat_args_as_list = false;
  size_t nargs = args ? PyTuple_Size(args) : 0;
  size_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;

  if (max_pos_nargs_ == 1) {
    const auto& type = def_->argument_def.at(0).type;
    treat_args_as_list = IsIntegralListType(type) || type == kSHAPE || type == kTENSOR_TUPLE;
  }
  if (nargs > max_pos_nargs_ && !treat_args_as_list) {
    if (raise_exception) {
      THROW(TypeError) << def_->name << "(): takes " << max_pos_nargs_
                       << " positional arguments but " << nargs << " were given";
    }
    return false;
  }
  int arg_pos = 0;
  for (int i = 0; i < def_->argument_def.size(); ++i) {
    const auto& param = def_->argument_def.at(i);
    PyObject* obj = NULL;
    if (args && arg_pos < nargs) {
      if (param.keyword_only) {
        if (raise_exception) {
          THROW(TypeError) << def_->name << "(): argument '" << param.name << "' is keyword only";
        }
        return false;
      }
      obj = PyTuple_GetItem(args, arg_pos);
    } else if (kwargs) {
      obj = PyDict_GetItemString(kwargs, param.name.c_str());
      if (obj) { remaining_kwargs--; }
    }

    if (obj) {
      if (arg_pos == 0 && treat_args_as_list && !param.keyword_only
          && (PyLong_Check(obj) || PyTensor_Check(obj))) {
        obj = args;
        arg_pos = nargs;
      } else {
        arg_pos++;
      }
      PythonArg arg(obj, param.size);
      if ((obj == Py_None && param.optional) || PythonArgCheck(arg, param.type)) {
        parsed_args[i] = arg;
      } else {
        if (raise_exception) {
          THROW(TypeError) << def_->name << "(): argument '" << param.name << "' must be "
                           << ValueTypeName(param.type) << ", not "
                           << PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj)));
        }
        return false;
      }
    } else {
      if (!param.has_default_value) {
        if (raise_exception) {
          THROW(TypeError) << def_->name << "(): missing required argument " << param.name;
        }
        return false;
      }
      parsed_args[i] = param.default_value;
    }
  }
  if (remaining_kwargs > 0) {
    if (raise_exception) { ReportKwargsError(kwargs, nargs); }
    return false;
  }
  return true;
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow
oneflow/api/python/functional/python_arg_parser.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_
#include <Python.h>
#include "oneflow/api/python/functional/function_def.h"
#include "oneflow/api/python/functional/python_arg.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace one {
namespace functional {

template<int N>
class ParsedArgs {
 public:
  ParsedArgs() = default;

  const PythonArg& operator[](size_t idx) const { return data[idx]; }
  PythonArg& operator[](size_t idx) { return data[idx]; }

 public:
  PythonArg data[N];
};

class FunctionSchema {
 public:
  FunctionSchema() = default;
  FunctionSchema(const std::string& signature, const FunctionDef* def, size_t max_pos_nargs)
      : signature_(signature), def_(def), max_pos_nargs_(max_pos_nargs) {}

  const std::string& signature() const { return signature_; }

  bool Parse(PyObject* args, PyObject* kwargs, PythonArg* parsed_args,
             bool raise_exception) const;

 private:
  void ReportKwargsError(PyObject* kwargs, size_t nargs) const;

  std::string signature_;
  const FunctionDef* def_;
  size_t max_pos_nargs_;
};

template<typename... SchemaT>
class PythonArgParser {
 public:
  static_assert(sizeof...(SchemaT) >= 1, "requires 1 template argument at least.");
  static constexpr size_t kSchemaSize = sizeof...(SchemaT);
  static constexpr size_t N = std::max({SchemaT::max_args...});

  template<size_t I>
  using schema_t = typename std::tuple_element<I, std::tuple<SchemaT...>>::type;

  PythonArgParser(const std::string& name) : name_(name) {
    Init(std::make_index_sequence<sizeof...(SchemaT)>{});
  }

  int Parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>* parsed_args) const {
    bool raise_exception = (kSchemaSize == 1);
    for (int i = 0; i < kSchemaSize; ++i) {
      if (schema_[i].Parse(args, kwargs, parsed_args->data, raise_exception)) { return i; }
    }
    ReportInvalidArgsError(args, kwargs);
    return -1;
  }

 private:
  template<size_t... I>
  void Init(std::index_sequence<I...>) {
    __attribute__((__unused__)) int dummy[] = {
        ((void)(schema_[I] = FunctionSchema(schema_t<I>::signature, &schema_t<I>::function_def,
                                            schema_t<I>::max_pos_args)),
         0)...};
  }

  void ReportInvalidArgsError(PyObject* args, PyObject* kwargs) const {
    std::ostringstream ss;
    ss << name_ << "(): received an invalid combination of arguments. The valid signatures are:";
    for (int i = 0; i < kSchemaSize; ++i) {
      ss << "\n\t*" << i << ": " << schema_[i].signature();
    }
    THROW(TypeError) << ss.str();
  }

 private:
  std::string name_;
  FunctionSchema schema_[kSchemaSize];
};

}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_ARG_PARSER_H_
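PythonArgParser is driven by generated schema structs, one per overload; Init and N above read four static members from each: signature, function_def, max_args, and max_pos_args. A hand-written sketch of that contract (in the real build these structs come from the code generator, so names and values here are assumptions):

// Hand-written sketch of the schema contract read by PythonArgParser.
struct SgdUpdateSchema {
  static constexpr size_t max_args = 7;      // total number of arguments
  static constexpr size_t max_pos_args = 7;  // of which this many may be positional
  static const std::string signature;        // the signature string from the YAML
  static const FunctionDef function_def;     // built as in function_def.h
};

// A generated binding would then parse and convert roughly like this:
//   static PythonArgParser<SgdUpdateSchema> parser("dispatch_sgd_update");
//   ParsedArgs<decltype(parser)::N> parsed;
//   int overload = parser.Parse(args, kwargs, &parsed);  // throws TypeError on mismatch
//   float learning_rate = parsed[2].As<float>();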
oneflow/api/python/functional/python_frame.h
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_FRAME_H_
#define ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_FRAME_H_
#include <Python.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/core/framework/op_interpreter/dispatch_frame.h"
#include "oneflow/core/job/graph_scope_vars.h"
namespace oneflow {
namespace one {
namespace functional {

namespace {

std::string get_cur_frame_stack_str(int32_t max_stack_depth) {
  std::string cur_f_str;
  PyFrameObject* cur_frame = PyEval_GetFrame();
  for (int32_t i = 0; i < max_stack_depth; i++) {
    if (cur_frame == NULL) break;
    const int32_t stack_index = (-1) * i - 1;
    cur_f_str = "Python Stack[" + std::to_string(stack_index)
                + "]: " + PyObjectToReprStr((PyObject*)cur_frame) + "; " + cur_f_str;
    cur_frame = cur_frame->f_back;
  }
  return cur_f_str;
}

int32_t get_cur_stack_depth() {
  int32_t current_stack_depth = 0;
  PyFrameObject* f = PyEval_GetFrame();
  while (f) {
    current_stack_depth++;
    f = f->f_back;
  }
  return current_stack_depth;
}

std::string get_cur_frame_stack_str() {
  const bool debug_mode = GetGraphDebugMode();
  const int32_t max_stack_depth = GetGraphDebugMaxPyStackDepth();
  if (debug_mode) {  // show more info for the stack trace in debug mode
    int32_t current_stack_depth = get_cur_stack_depth();
    std::string cur_f_str = get_cur_frame_stack_str(max_stack_depth);
    if (current_stack_depth > max_stack_depth) {
      // show how many stack depth remaining to be shown
      int32_t remaining_stack_depth = current_stack_depth - max_stack_depth;
      cur_f_str += " ... " + std::to_string(remaining_stack_depth) + " more; ";
    }
    return cur_f_str;
  }
  return get_cur_frame_stack_str(max_stack_depth);
}

}  // namespace

class PythonFrameGuard {
 public:
  PythonFrameGuard() {
    if (OF_PREDICT_FALSE(LazyMode::is_enabled())) {
      prev_frame_str_ = DispatchFrame::get_str();
      DispatchFrame::set_str(get_cur_frame_stack_str());
    }
  }
  ~PythonFrameGuard() {
    if (OF_PREDICT_FALSE(LazyMode::is_enabled())) {
      DispatchFrame::set_str(prev_frame_str_);
    }
  }

 private:
  std::string prev_frame_str_;
};

}  // namespace functional
}  // namespace one
}  // namespace oneflow
#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_PYTHON_FRAME_H_
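PythonFrameGuard is intended to bracket each functional API entry point: under lazy (nn.Graph) mode it swaps the current Python stack string into DispatchFrame for the duration of the call and restores the previous one on exit. A sketch of that intended placement (the binding function itself is illustrative):

// Illustrative binding showing the guard's intended placement; in eager mode
// its constructor and destructor are no-ops.
static PyObject* functional_example(PyObject* self, PyObject* args, PyObject* kwargs) {
  oneflow::one::functional::PythonFrameGuard frame_guard;
  // ... parse arguments and dispatch the op; any op expression created in this
  // scope under LazyMode now carries the recorded Python stack string ...
  Py_RETURN_NONE;
}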
oneflow/api/python/functional/tensor_api.cpp
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <Python.h>
#include <memory>
#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/functional/tensor_api.yaml.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/common/foreign_lock_helper.h"
namespace
oneflow
{
namespace
one
{
namespace
functional
{
namespace
impl
{
class
TensorWithDataFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Optional
<
Symbol
<
Device
>>&
device
,
const
bool
requires_grad
,
const
bool
pin_memory
)
const
{
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
// even if in nn.Graph build (module forward function), if you create a flow.Tensor,
// its a eager tensor by Run functional::Empty() in LazyMode::Grad(false)
LazyMode
::
Guard
lazy_mode_disabled_guard
(
/*is_enabled*/
false
);
if
(
PyTensor_Check
(
data
))
{
// Throw warnings like pytorch.
auto
ret
=
PyErr_WarnEx
(
PyExc_UserWarning
,
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
"or sourceTensor.clone().detach().requires_grad_(True), rather than "
"oneflow.tensor(sourceTensor)."
,
1
);
if
(
ret
!=
0
)
{
return
Error
::
RuntimeError
();
}
const
auto
&
other
=
PyTensor_Unpack
(
data
);
return
MakeTensorFromOtherTensor
(
other
,
dtype
,
device
,
requires_grad
,
pin_memory
);
}
else
{
// Make tensor from python sequence or numpy array.
return
MakeLocalTensorFromData
(
data
,
dtype
,
device
,
requires_grad
,
pin_memory
);
}
}
};
class
ConsistentTensorWithDataFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
PyObject
*
data
,
const
Optional
<
Symbol
<
DType
>>&
dtype
,
const
Symbol
<
ParallelDesc
>&
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
,
const
bool
requires_grad
)
const
{
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
LazyMode
::
Guard
lazy_mode_disabled_guard
(
/*is_enabled*/
false
);
JUST
(
CheckDeviceIdsIsValid
(
placement
));
if
(
PyTensor_Check
(
data
))
{
// Throw warnings like pytorch.
auto
ret
=
PyErr_WarnEx
(
PyExc_UserWarning
,
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
"or sourceTensor.clone().detach().requires_grad_(True), rather than "
"oneflow.tensor(sourceTensor)."
,
1
);
if
(
ret
!=
0
)
{
return
Error
::
RuntimeError
();
}
const
auto
&
other
=
PyTensor_Unpack
(
data
);
return
MakeTensorFromOtherTensor
(
other
,
dtype
,
placement
,
sbp_tuple
,
requires_grad
);
}
// Make consistent tensor from python sequence or numpy array.
return
MakeConsistentTensorFromData
(
data
,
dtype
,
placement
,
sbp_tuple
,
requires_grad
);
}
};
class
TensorEmptyCtorFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
const
Optional
<
Symbol
<
Device
>>&
device
)
const
{
Shape
shape
(
DimVector
{
0
});
return
TensorWithShapeCtor
(
shape
,
device
);
}
};
class
ConsistentTensorEmptyCtorFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
const
Symbol
<
ParallelDesc
>&
placement
,
const
std
::
vector
<
Symbol
<
SbpParallel
>>&
sbp_tuple
)
const
{
Shape
shape
(
DimVector
{
0
});
JUST
(
CheckDeviceIdsIsValid
(
placement
));
return
ConsistentTensorWithShapeCtor
(
shape
,
placement
,
sbp_tuple
);
}
};
class
TensorWithOtherCtorFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
const
std
::
shared_ptr
<
Tensor
>&
other
)
const
{
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
LazyMode
::
Guard
lazy_mode_disabled_guard
(
/*is_enabled*/
false
);
bool
is_pinned
=
false
;
if
(
other
->
is_local
())
{
is_pinned
=
JUST
(
CHECK_JUST
(
other
->
AsMirroredTensor
())
->
is_pinned
());
}
return
MakeTensorFromOtherTensor
(
other
,
is_pinned
);
}
};
class
TensorWithDataCtorFunctor
{
public:
Maybe
<
Tensor
>
operator
()(
PyObject
*
data
,
const
Optional
<
Symbol
<
Device
>>&
device
)
const
{
// Treat the single long as shape.
if
(
PyLong_Check
(
data
))
{
int64_t
size
=
PyLong_AsLongLong
(
data
);
Shape
shape
(
DimVector
{
size
});
return
TensorWithShapeCtor
(
shape
,
device
);
}
if
(
TensorSize_Check
(
data
))
{
return
TensorWithShapeCtor
(
TensorSize_AsShape
(
data
),
device
);
}
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
LazyMode
::
Guard
lazy_mode_disabled_guard
(
/*is_enabled*/
false
);
const
auto
&
dtype
=
DType
::
Float
();
if
(
PyTensor_Check
(
data
))
{
const
auto
&
other
=
PyTensor_Unpack
(
data
);
const
bool
pin_memory
=
other
->
is_local
()
?
JUST
(
JUST
(
other
->
AsMirroredTensor
())
->
is_pinned
())
:
false
;
return
MakeTensorFromOtherTensor
(
other
,
dtype
,
device
,
/*requires_grad=*/
false
,
/*pin_memory=*/
pin_memory
);
}
// Make tensor from python sequence or numpy array.
return
MakeLocalTensorFromData
(
data
,
dtype
,
device
,
/*requires_grad=*/
false
,
/*pin_memory=*/
false
);
}
};
class ConsistentTensorWithDataCtorFunctor {
 public:
  Maybe<Tensor> operator()(PyObject* data, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {
    JUST(CheckDeviceIdsIsValid(placement));
    // Treat the single long as shape.
    if (PyLong_Check(data)) {
      int64_t size = PyLong_AsLongLong(data);
      Shape shape(DimVector{size});
      return ConsistentTensorWithShapeCtor(shape, placement, sbp_tuple);
    }
    if (TensorSize_Check(data)) {
      return ConsistentTensorWithShapeCtor(TensorSize_AsShape(data), placement, sbp_tuple);
    }

    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
    const auto& dtype = DType::Float();
    if (PyTensor_Check(data)) {
      const auto& other = PyTensor_Unpack(data);
      return MakeTensorFromOtherTensor(other, dtype, placement, sbp_tuple,
                                       /*requires_grad=*/false);
    }
    // Make consistent tensor from python sequence or numpy array.
    return MakeConsistentTensorFromData(data, dtype, placement, sbp_tuple,
                                        /*requires_grad=*/false);
  }
};
class TensorWithShapeCtorFunctor {
 public:
  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<Device>>& device) const {
    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
    Symbol<Device> device_;
    if (device) {
      device_ = JUST(device);
    } else {
      device_ = JUST(Device::New("cpu"));
    }
    return functional::Empty(shape, DType::Float(), device_, /*pin_memory=*/false);
  }
};
class ConsistentTensorWithShapeCtorFunctor {
 public:
  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<SbpParallel>>& sbp_tuple) const {
    // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
    LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
    JUST(CheckDeviceIdsIsValid(placement));
    return functional::ConsistentEmpty(shape, DType::Float(), placement, sbp_tuple);
  }
};
class AssignLocalTensorFunctor {
 public:
  AssignLocalTensorFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("assign").Input("ref").Input("value").Build());
  }
  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& ref,
                         const std::shared_ptr<one::Tensor>& value) const {
    // JUST(CheckInplaceValid(ref));  // align check to torch
    CHECK_OR_RETURN(ref->is_local() && value->is_local())
        << "Both ref and value must be local tensor.";
    JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {ref, value}));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
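AssignLocalTensor performs an in-place copy of `value` into `ref` through the builtin "assign" op; both operands must be local. Assuming the generated binding is exposed under `oneflow._C` like other YAML-declared functions (an assumption, not confirmed by this commit), usage would look like:

import oneflow as flow

ref = flow.zeros(2, 2)
value = flow.ones(2, 2)
flow._C.assign_local_tensor(ref, value)   # hypothetical binding path; see the YAML entry below
print(ref)  # now all ones; ref keeps its identity, only its storage is overwritten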
class LocalTensorSharedNumpyDataFunctor {
 public:
  LocalTensorSharedNumpyDataFunctor() {}
  Maybe<Tensor> operator()(PyObject* obj) const {
    if (!PyArray_Check(obj)) {
      return Error::TypeError() << "expected np.ndarray, but got " << Py_TYPE(obj)->tp_name;
    }
    auto* array = reinterpret_cast<PyArrayObject*>(obj);

    // TODO(wyg): support non-contiguous array.
    if (!PyArray_IS_C_CONTIGUOUS(array)) {
      OF_LOG_ONCE(LOG(WARNING) << "OneFlow does not support non-contiguous arrays for now, "
                                  "so the array will be copied to a contiguous one.");
      // PyArray_GETCONTIGUOUS returns a new reference if the array is already contiguous,
      // otherwise a (contiguous) copy of the array.
      // Note: the array's reference count is incremented whether or not it is contiguous.
      array = PyArray_GETCONTIGUOUS(array);
    } else {
      Py_INCREF(obj);
    }

    // Build TensorMeta
    int32_t dim = PyArray_NDIM(array);
    const npy_intp* dims_ptr = PyArray_SHAPE(array);
    const auto shape = std::make_shared<Shape>(DimVector(dims_ptr, dims_ptr + dim));
    DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(array));
    Symbol<Device> device = JUST(Device::New("cpu"));
    const npy_intp* stride_ptr = PyArray_STRIDES(array);
    auto strides = std::make_shared<Stride>(stride_ptr, stride_ptr + dim);
    auto element_size_in_bytes = PyArray_ITEMSIZE(array);
    // NumPy strides use bytes; OneFlow strides use element counts.
    for (auto& stride_val : *strides) {
      if (stride_val % element_size_in_bytes != 0) {
        return Error::RuntimeError() << "given numpy array strides are not a multiple of the "
                                        "element byte size. Copy the numpy array to reallocate "
                                        "the memory.";
      }
      stride_val /= element_size_in_bytes;
    }
    auto tensor_meta = std::make_shared<MirroredTensorMeta>(shape, strides, data_type, device, 0);

    // Build TensorBuffer
    const auto& Free = [array](char* dptr) {
      CHECK_JUST(Singleton<ForeignLockHelper>::Get()->WithScopedAcquire([&]() -> Maybe<void> {
        Py_DECREF(array);
        return Maybe<void>::Ok();
      }));
    };
    void* data_ptr = PyArray_DATA(array);
    auto array_size_in_bytes = PyArray_NBYTES(array);
    auto tensor_data = std::make_shared<vm::TensorStorage>();
    tensor_data->set_blob_dptr(
        std::unique_ptr<char, std::function<void(char*)>>(static_cast<char*>(data_ptr), Free),
        array_size_in_bytes);

    // Build TensorStorage: the ndarray's reference count is decreased when the storage is released.
    auto tensor_storage = std::make_shared<TensorStorage>(tensor_data);

    // Build Tensor
    auto tensor_impl = std::make_shared<EagerMirroredTensorImpl>(tensor_meta, tensor_storage,
                                                                 /*requires_grad=*/false,
                                                                 /*is_leaf=*/true);

    // Init blob
    JUST(tensor_impl->InitEagerBlobObject(NewLocalDepObject()));
    const auto& stream = JUST(GetDefaultStreamByDevice(device));
    const auto& eager_blob_object = JUST(tensor_impl->eager_blob_object());
    JUST(eager_blob_object->init_producer_stream(stream));
    eager_blob_object->set_last_used_stream(stream);
    std::shared_ptr<Tensor> out(new MirroredTensor(tensor_impl));
    return out;
  }
};
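Because the tensor takes a reference to the ndarray and adopts its buffer pointer directly, `flow.from_numpy` is zero-copy for contiguous arrays: tensor and array see the same memory. Note also the stride rescaling above: a C-contiguous float32 array of shape (2, 3) has NumPy strides (12, 4) in bytes, which become OneFlow strides (3, 1) in elements. An illustrative sketch:

import numpy as np
import oneflow as flow

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
t = flow.from_numpy(arr)      # shares arr's buffer, no copy

arr[0, 0] = 100.0
print(t[0, 0])                # reflects the in-place change to arr

t2 = flow.from_numpy(arr.T)   # non-contiguous: warns once and copies instead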
}  // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::TensorWithDataFunctor>("TensorWithData");
  m.add_functor<impl::ConsistentTensorWithDataFunctor>("ConsistentTensorWithData");
  m.add_functor<impl::TensorEmptyCtorFunctor>("TensorEmptyCtor");
  m.add_functor<impl::ConsistentTensorEmptyCtorFunctor>("ConsistentTensorEmptyCtor");
  m.add_functor<impl::TensorWithOtherCtorFunctor>("TensorWithOtherCtor");
  m.add_functor<impl::TensorWithDataCtorFunctor>("TensorWithDataCtor");
  m.add_functor<impl::ConsistentTensorWithDataCtorFunctor>("ConsistentTensorWithDataCtor");
  m.add_functor<impl::TensorWithShapeCtorFunctor>("TensorWithShapeCtor");
  m.add_functor<impl::ConsistentTensorWithShapeCtorFunctor>("ConsistentTensorWithShapeCtor");
  m.add_functor<impl::AssignLocalTensorFunctor>("AssignLocalTensor");
  m.add_functor<impl::LocalTensorSharedNumpyDataFunctor>("LocalTensorSharedNumpyData");
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow
oneflow/api/python/functional/tensor_api.yaml
0 → 100644
View file @
21d47d0e
# Copyright 2020 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- name: "tensor"
  signature: [
    "Tensor (PyObject* data, *, DataType dtype=None, Device device=None, Bool requires_grad=False, Bool pin_memory=False) => TensorWithData",
    "Tensor (PyObject* data, *, DataType dtype=None, Placement placement, SbpList sbp, Bool requires_grad=False) => ConsistentTensorWithData",
  ]
  bind_python: True

- name: "_legacy_tensor_ctor"
  signature: [
    "Tensor (*, Device device=None) => TensorEmptyCtor",
    "Tensor (*, Placement placement, SbpList sbp) => ConsistentTensorEmptyCtor",
    "Tensor (Tensor other) => TensorWithOtherCtor",
    "Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor",
    "Tensor (PyObject* data, *, Placement placement, SbpList sbp) => ConsistentTensorWithDataCtor",
    "Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor",
    "Tensor (Shape size, *, Placement placement, SbpList sbp) => ConsistentTensorWithShapeCtor",
  ]
  bind_python: True

- name: "assign_local_tensor"
  signature: "Void (Tensor ref, Tensor value) => AssignLocalTensor"
  bind_python: True

- name: "from_numpy"
  signature: "Tensor (PyObject* obj) => LocalTensorSharedNumpyData"
  bind_python: True
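In these signature strings, `*` marks the start of keyword-only parameters and `=>` names the registered C++ functor the generated binding dispatches to. So dtype, for instance, must be passed by keyword (an illustrative sketch):

import oneflow as flow

flow.tensor([1, 2], dtype=flow.int64)   # OK: dtype is keyword-only
flow.tensor([1, 2], flow.int64)         # TypeError: no matching overload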
oneflow/api/python/functional/value_types.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/value_types.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/hash_container.h"
namespace oneflow {
namespace one {
namespace functional {

HashMap<ValueType, std::string>* GetValueTypeNameMap() {
  static HashMap<ValueType, std::string> value_type_name_map = {
      {kVOID, "void"},
      {kINT32, "int32"},
      {kUINT32, "unsigned int32"},
      {kINT64, "int64"},
      {kUINT64, "unsigned int64"},
      {kFLOAT, "float"},
      {kDOUBLE, "double"},
      {kBOOL, "bool"},
      {kSTRING, "string"},
      {kINT32_LIST, "int32 list"},
      {kUINT32_LIST, "unsigned int32 list"},
      {kINT64_LIST, "int64 list"},
      {kUINT64_LIST, "unsigned int64 list"},
      {kFLOAT_LIST, "float list"},
      {kDOUBLE_LIST, "double list"},
      {kBOOL_LIST, "bool list"},
      {kSTRING_LIST, "string list"},
      {kVOID_MAYBE, "maybe void"},
      {kBOOL_MAYBE, "maybe bool"},
      {kSCALAR, "scalar"},
      {kTENSOR, "tensor"},
      {kTENSOR_REF, "tensor"},
      {kTENSOR_MAYBE, "maybe tensor"},
      {kTENSOR_TUPLE, "tensor tuple"},
      {kTENSOR_TUPLE_REF, "tensor tuple"},
      {kTENSOR_TUPLE_MAYBE, "maybe tensor tuple"},
      {kATTR, "attr"},
      {kATTR_REF, "attr"},
      {kDTYPE, "data type"},
      {kDTYPE_LIST, "data type list"},
      {kSHAPE, "shape"},
      {kSHAPE_LIST, "shape list"},
      {kGENERATOR, "generator"},
      {kGENERATOR_REF, "generator"},
      {kGENERATOR_MAYBE, "maybe generator"},
      {kTENSOR_INDEX, "index"},
      {kDEVICE, "device"},
      {kPARALLEL_DESC, "placement"},
      {kSBP_PARALLEL, "sbp"},
      {kSBP_PARALLEL_LIST, "sbp list"},
      {kOPEXPR, "opexpr"},
      {kOPEXPR_REF, "opexpr"},
      {kPY_OBJECT, "python object"},
  };
  return &value_type_name_map;
}

const std::string& ValueTypeName(ValueType type) {
  const auto* type_name_map = GetValueTypeNameMap();
  const auto& it = type_name_map->find(type);
  CHECK_OR_THROW(it != type_name_map->end()) << "Value type " << type << " has no type name.";
  return it->second;
}

// Each category occupies a contiguous band of enum values that ends at its
// corresponding *_MASK sentinel, so classification is a simple range check.
bool IsIntegralType(ValueType type) { return type >= kINT32 && type < kINTEGRAL_MASK; }
bool IsIntegralListType(ValueType type) {
  return type >= kINT32_LIST && type < kINTEGRAL_LIST_MASK;
}
bool IsFloatingType(ValueType type) { return type >= kFLOAT && type < kFLOATING_MASK; }
bool IsFloatingListType(ValueType type) {
  return type >= kFLOAT_LIST && type < kFLOATING_LIST_MASK;
}

}  // namespace functional
}  // namespace one
}  // namespace oneflow
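The Is*Type predicates work because value_types.h (next file) lays the enum out in contiguous bands terminated by *_MASK sentinels. A Python rendering of the same scheme, with the constants copied from the header:

# Enum values from value_types.h: kINT32=2 ... kUINT64=5, kINTEGRAL_MASK=10,
# kFLOAT=11, kDOUBLE=12, kFLOATING_MASK=15.
kINT32, kINTEGRAL_MASK = 2, 10
kFLOAT, kFLOATING_MASK = 11, 15

def is_integral(v):   # mirrors IsIntegralType
    return kINT32 <= v < kINTEGRAL_MASK

def is_floating(v):   # mirrors IsFloatingType
    return kFLOAT <= v < kFLOATING_MASK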
oneflow/api/python/functional/value_types.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_
#define ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_
#include <memory>
#include <Python.h>
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/dtype.h"
namespace oneflow {

class Scalar;
class Shape;
template<typename T>
class Symbol;
class Device;
class ParallelDesc;
class SbpParallel;

namespace one {
class Tensor;
class TensorTuple;
class Generator;
class OpExpr;

namespace functional {
class TensorIndex;
}  // namespace functional
}  // namespace one

namespace one {
namespace functional {

enum ValueType : int {
  kINVALID = 0,
  kVOID,
  // Integral
  kINT32,
  kINT64,
  kUINT32,
  kUINT64,
  kINTEGRAL_MASK = 10,
  // Floating
  kFLOAT,
  kDOUBLE,
  kFLOATING_MASK = 15,
  kBOOL,
  kSTRING,
  // Integral list
  kINT32_LIST = 50,
  kUINT32_LIST,
  kINT64_LIST,
  kUINT64_LIST,
  kINTEGRAL_LIST_MASK = 60,
  // Floating list
  kFLOAT_LIST,
  kDOUBLE_LIST,
  kFLOATING_LIST_MASK = 65,
  kBOOL_LIST,
  kSTRING_LIST,

  kVOID_MAYBE = 100,
  kBOOL_MAYBE,

  kSCALAR = 200,
  kTENSOR,
  kTENSOR_REF,
  kTENSOR_MAYBE,
  kTENSOR_TUPLE,
  kTENSOR_TUPLE_REF,
  kTENSOR_TUPLE_MAYBE,
  kATTR,
  kATTR_REF,
  kDTYPE,
  kSHAPE,
  kGENERATOR,
  kGENERATOR_REF,
  kGENERATOR_MAYBE,
  kTENSOR_INDEX,
  kDEVICE,
  kPARALLEL_DESC,
  kSBP_PARALLEL,
  kSBP_PARALLEL_LIST,
  kSHAPE_LIST,
  kDTYPE_LIST,

  kOPEXPR = 390,
  kOPEXPR_REF,

  kPY_OBJECT = 400,
};
#define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \
template<typename T, typename std::enable_if<std::is_same<T, cpp_type>::value, int>::type = 0> \
inline ValueType ValueTypeOf() { \
return value_type; \
} \
template<typename T, \
typename std::enable_if<std::is_same<T, Optional<cpp_type>>::value, int>::type = 0> \
inline ValueType ValueTypeOf() { \
return value_type; \
}
VALUE_TYPE_OF_IMPL(void, kVOID);
VALUE_TYPE_OF_IMPL(int32_t, kINT32);
VALUE_TYPE_OF_IMPL(uint32_t, kUINT32);
VALUE_TYPE_OF_IMPL(int64_t, kINT64);
VALUE_TYPE_OF_IMPL(uint64_t, kUINT64);
VALUE_TYPE_OF_IMPL(float, kFLOAT);
VALUE_TYPE_OF_IMPL(double, kDOUBLE);
VALUE_TYPE_OF_IMPL(bool, kBOOL);
VALUE_TYPE_OF_IMPL(std::string, kSTRING);
VALUE_TYPE_OF_IMPL(std::vector<int32_t>, kINT32_LIST);
VALUE_TYPE_OF_IMPL(std::vector<uint32_t>, kUINT32_LIST);
VALUE_TYPE_OF_IMPL(std::vector<int64_t>, kINT64_LIST);
VALUE_TYPE_OF_IMPL(std::vector<uint64_t>, kUINT64_LIST);
VALUE_TYPE_OF_IMPL(std::vector<float>, kFLOAT_LIST);
VALUE_TYPE_OF_IMPL(std::vector<double>, kDOUBLE_LIST);
VALUE_TYPE_OF_IMPL(std::vector<bool>, kBOOL_LIST);
VALUE_TYPE_OF_IMPL(std::vector<std::string>, kSTRING_LIST);
VALUE_TYPE_OF_IMPL(Maybe<void>, kVOID_MAYBE);
VALUE_TYPE_OF_IMPL(Maybe<bool>, kBOOL_MAYBE);
VALUE_TYPE_OF_IMPL(Scalar, kSCALAR);
VALUE_TYPE_OF_IMPL(one::Tensor, kTENSOR);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Tensor>, kTENSOR_REF);
VALUE_TYPE_OF_IMPL(Maybe<one::Tensor>, kTENSOR_MAYBE);
VALUE_TYPE_OF_IMPL(one::TensorTuple, kTENSOR_TUPLE);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::TensorTuple>, kTENSOR_TUPLE_REF);
VALUE_TYPE_OF_IMPL(Maybe<one::TensorTuple>, kTENSOR_TUPLE_MAYBE);
VALUE_TYPE_OF_IMPL(Symbol<DType>, kDTYPE);
VALUE_TYPE_OF_IMPL(std::vector<Symbol<DType>>, kDTYPE_LIST);
VALUE_TYPE_OF_IMPL(Shape, kSHAPE);
VALUE_TYPE_OF_IMPL(std::vector<Shape>, kSHAPE_LIST);
VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);
VALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);
VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);
VALUE_TYPE_OF_IMPL(Symbol<Device>, kDEVICE);
VALUE_TYPE_OF_IMPL(Symbol<ParallelDesc>, kPARALLEL_DESC);
VALUE_TYPE_OF_IMPL(Symbol<SbpParallel>, kSBP_PARALLEL);
VALUE_TYPE_OF_IMPL(std::vector<Symbol<SbpParallel>>, kSBP_PARALLEL_LIST);
VALUE_TYPE_OF_IMPL(one::OpExpr, kOPEXPR);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::OpExpr>, kOPEXPR_REF);
VALUE_TYPE_OF_IMPL(PyObject*, kPY_OBJECT);
VALUE_TYPE_OF_IMPL(const PyObject*, kPY_OBJECT);

#undef VALUE_TYPE_OF_IMPL
const std::string& ValueTypeName(ValueType type);

bool IsIntegralType(ValueType type);
bool IsIntegralListType(ValueType type);
bool IsFloatingType(ValueType type);
bool IsFloatingListType(ValueType type);

}  // namespace functional
}  // namespace one
}  // namespace oneflow

namespace std {
template<>
struct hash<oneflow::one::functional::ValueType> {
  std::size_t operator()(oneflow::one::functional::ValueType v) const noexcept { return v; }
};
}  // namespace std
#endif // ONEFLOW_CORE_FUNCTIONAL_VALUE_TYPES_H_