Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
21d47d0e
Commit
21d47d0e
authored
Oct 24, 2022
by
yuguo
Browse files
Oneflow 0.8 for DCU
parents
Changes
556
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2445 additions
and
0 deletions
+2445
-0
oneflow/core/common/buffer_manager.h
oneflow/core/common/buffer_manager.h
+100
-0
oneflow/core/common/cached_caller.cpp
oneflow/core/common/cached_caller.cpp
+28
-0
oneflow/core/common/cached_caller.h
oneflow/core/common/cached_caller.h
+98
-0
oneflow/core/common/cblas.h
oneflow/core/common/cblas.h
+491
-0
oneflow/core/common/channel.h
oneflow/core/common/channel.h
+90
-0
oneflow/core/common/channel_test.cpp
oneflow/core/common/channel_test.cpp
+63
-0
oneflow/core/common/check_level.cpp
oneflow/core/common/check_level.cpp
+30
-0
oneflow/core/common/check_level.h
oneflow/core/common/check_level.h
+25
-0
oneflow/core/common/constant.h
oneflow/core/common/constant.h
+33
-0
oneflow/core/common/container_util.h
oneflow/core/common/container_util.h
+99
-0
oneflow/core/common/container_util_test.cpp
oneflow/core/common/container_util_test.cpp
+54
-0
oneflow/core/common/cplusplus_17.h
oneflow/core/common/cplusplus_17.h
+80
-0
oneflow/core/common/cplusplus_17_test.cpp
oneflow/core/common/cplusplus_17_test.cpp
+63
-0
oneflow/core/common/cpp_attribute.h
oneflow/core/common/cpp_attribute.h
+24
-0
oneflow/core/common/data_type.cpp
oneflow/core/common/data_type.cpp
+188
-0
oneflow/core/common/data_type.h
oneflow/core/common/data_type.h
+279
-0
oneflow/core/common/data_type.proto
oneflow/core/common/data_type.proto
+31
-0
oneflow/core/common/data_type_converter.h
oneflow/core/common/data_type_converter.h
+406
-0
oneflow/core/common/data_type_converter_test.cpp
oneflow/core/common/data_type_converter_test.cpp
+201
-0
oneflow/core/common/data_type_converter_test_static.h
oneflow/core/common/data_type_converter_test_static.h
+62
-0
No files found.
Too many changes to show.
To preserve performance only
556 of 556+
files are displayed.
Plain diff
Email patch
oneflow/core/common/buffer_manager.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
#define ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/buffer.h"
namespace
oneflow
{
template
<
typename
T
>
class
BufferMgr
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
BufferMgr
);
~
BufferMgr
()
=
default
;
void
NewBuffer
(
const
std
::
string
&
buffer_name
,
size_t
buffer_size
)
{
CHECK
(
name2buffer_
.
emplace
(
buffer_name
,
std
::
make_unique
<
Buffer
<
T
>>
(
buffer_size
)).
second
);
}
Buffer
<
T
>*
Get
(
const
std
::
string
&
buffer_name
)
const
{
const
auto
&
iter
=
name2buffer_
.
find
(
buffer_name
);
CHECK
(
iter
!=
name2buffer_
.
end
())
<<
"buffer_name: "
<<
buffer_name
;
return
iter
->
second
.
get
();
}
private:
friend
class
Singleton
<
BufferMgr
>
;
BufferMgr
()
=
default
;
HashMap
<
std
::
string
,
std
::
unique_ptr
<
Buffer
<
T
>>>
name2buffer_
;
};
// Fixed buffer name "GlobalWaitJobId". Declared `static` in this header, so
// every translation unit that includes it gets its own copy.
static const std::string kBufferNameGlobalWaitJobId("GlobalWaitJobId");
// Returns "CallbackNotifier-" followed by `job_name`.
inline std::string GetCallbackNotifierBufferName(const std::string& job_name) {
  std::string buffer_name("CallbackNotifier-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "InputCriticalSectionWait-" followed by `job_name`.
inline std::string GetInputCriticalSectionWaitBufferName(const std::string& job_name) {
  std::string buffer_name("InputCriticalSectionWait-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "InputCriticalSectionCallback-" followed by `job_name`.
inline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) {
  std::string buffer_name("InputCriticalSectionCallback-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "OutputCriticalSectionWait-" followed by `job_name`.
inline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) {
  std::string buffer_name("OutputCriticalSectionWait-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "OutputCriticalSectionCallback-" followed by `job_name`.
inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) {
  std::string buffer_name("OutputCriticalSectionCallback-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "ForeignInput-" followed by `job_name`.
inline std::string GetForeignInputBufferName(const std::string& job_name) {
  std::string buffer_name("ForeignInput-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "ForeignOutput-" followed by `job_name`.
inline std::string GetForeignOutputBufferName(const std::string& job_name) {
  std::string buffer_name("ForeignOutput-");
  buffer_name.append(job_name);
  return buffer_name;
}
// Returns "ForeignInput-<job_name>-<op_name>".
// NOTE(review): the prefix is the same "ForeignInput-" used by
// GetForeignInputBufferName, so GetInputBufferName("a", "b") yields the same
// string as GetForeignInputBufferName("a-b") -- confirm that job names can
// never make the two collide.
inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) {
  std::string buffer_name("ForeignInput-");
  buffer_name.append(job_name).append("-").append(op_name);
  return buffer_name;
}
// Returns "ForeignOutput-<job_name>-<op_name>".
// NOTE(review): the prefix is the same "ForeignOutput-" used by
// GetForeignOutputBufferName, so GetOutputBufferName("a", "b") yields the same
// string as GetForeignOutputBufferName("a-b") -- confirm that job names can
// never make the two collide.
inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) {
  std::string buffer_name("ForeignOutput-");
  buffer_name.append(job_name).append("-").append(op_name);
  return buffer_name;
}
// Returns "SourceTick-" followed by `job_name`.
inline std::string GetSourceTickBufferName(const std::string& job_name) {
  std::string buffer_name("SourceTick-");
  buffer_name.append(job_name);
  return buffer_name;
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
oneflow/core/common/cached_caller.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/cached_caller.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
namespace
oneflow
{
bool
IsThreadLocalCacheEnabled
()
{
if
(
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
==
nullptr
)
{
return
true
;
}
return
Singleton
<
ResourceDesc
,
ForSession
>::
Get
()
->
enable_thread_local_cache
();
}
}
// namespace oneflow
oneflow/core/common/cached_caller.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
#define ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
#include <list>
#include <tuple>
#include <thread>
#include "oneflow/core/common/function_traits.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/tuple_hash.h"
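// gcc 11 reports the following error:
// ‘void operator delete(void*, std::size_t)’ called on unallocated object ‘cache’
// `DeleteAndClear` is only called after `cache` is allocated in the
// if (cache == nullptr) block, so the heap object behind `cache` is always valid.
// NOTE(review): the diagnostic names the thread_local variable `cache` itself, which is
// exactly what a `delete` applied to the `T**` parameter of `DeleteAndClear` (rather than
// to `*ptr`) would free; re-check whether this suppression is still required once
// `DeleteAndClear` deletes the pointee.
// The reason not to use #pragma GCC diagnostic push/pop is that gcc reports
// the error on the caller of `ThreadLocalCachedCall`.
// TODO: replace ThreadLocalCachedCall with ThreadLocalCached decorator?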
#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 11
#pragma GCC diagnostic ignored "-Wfree-nonheap-object"
#endif
namespace
oneflow
{
// Destroys the heap object that `*ptr` points to and resets `*ptr` to nullptr.
//
// ptr:     address of the owning raw pointer; must not be null. `*ptr` may be
//          null, in which case nothing is destroyed.
// obj_cnt: number of elements held by the object, used only to decide where
//          the destruction runs. Up to kThreshold elements are destroyed on
//          the calling thread; larger objects are handed to a background
//          thread so the caller does not pay for the teardown.
template<typename T>
void DeleteAndClear(T** ptr, size_t obj_cnt) {
  static const size_t kThreshold = 4096;
  if (obj_cnt <= kThreshold) {
    // Delete the pointee, not the T** itself: `ptr` is the address of the
    // caller's pointer variable and was never allocated with `new`.
    delete *ptr;
  } else {
    // The thread receives its own copy of the pointer value, so clearing
    // `*ptr` below does not race with it. It must be detached: destroying a
    // joinable std::thread calls std::terminate.
    std::thread([](T* doomed) { delete doomed; }, *ptr).detach();
  }
  *ptr = nullptr;
}
bool
IsThreadLocalCacheEnabled
();
template
<
typename
F
,
typename
Ret
=
typename
function_traits
<
F
>
::
return_type
,
typename
RawArg
=
typename
std
::
tuple_element
<
0
,
typename
function_traits
<
F
>::
args_type
>::
type
,
typename
Arg
=
typename
std
::
remove_const
<
typename
std
::
remove_reference
<
RawArg
>::
type
>::
type
>
Ret
ThreadLocalCachedCall
(
size_t
max_size
,
F
f
,
const
Arg
&
arg
)
{
if
(
IsThreadLocalCacheEnabled
()
==
false
)
{
return
f
(
arg
);
}
using
HashMap
=
std
::
unordered_map
<
HashEqTraitPtr
<
const
Arg
>
,
Ret
>
;
using
KeyStorage
=
std
::
list
<
std
::
unique_ptr
<
Arg
>>
;
static
thread_local
HashMap
*
cache
=
nullptr
;
static
thread_local
KeyStorage
*
key_storage
=
nullptr
;
if
(
cache
!=
nullptr
&&
cache
->
size
()
>=
max_size
)
{
DeleteAndClear
(
&
cache
,
cache
->
size
());
DeleteAndClear
(
&
key_storage
,
cache
->
size
());
}
if
(
cache
==
nullptr
)
{
cache
=
new
HashMap
();
key_storage
=
new
KeyStorage
();
}
size_t
hash_value
=
std
::
hash
<
Arg
>
()(
arg
);
{
HashEqTraitPtr
<
const
Arg
>
ptr_wrapper
(
&
arg
,
hash_value
);
const
auto
&
iter
=
cache
->
find
(
ptr_wrapper
);
if
(
iter
!=
cache
->
end
())
{
return
iter
->
second
;
}
}
Arg
*
new_arg
=
new
Arg
(
arg
);
key_storage
->
emplace_back
(
new_arg
);
HashEqTraitPtr
<
const
Arg
>
ptr_wrapper
(
new_arg
,
hash_value
);
return
cache
->
emplace
(
ptr_wrapper
,
f
(
*
new_arg
)).
first
->
second
;
}
template
<
typename
F
,
typename
Ret
=
typename
function_traits
<
F
>
::
return_type
,
typename
RawArg
=
typename
std
::
tuple_element
<
0
,
typename
function_traits
<
F
>::
args_type
>::
type
,
typename
Arg
=
typename
std
::
remove_const
<
typename
std
::
remove_reference
<
RawArg
>::
type
>::
type
>
std
::
function
<
Ret
(
const
Arg
&
)
>
WithResultCached
(
F
f
)
{
auto
cache
=
std
::
make_shared
<
std
::
unordered_map
<
Arg
,
Ret
>>
();
return
[
cache
,
f
](
const
Arg
&
arg
)
->
Ret
{
const
auto
&
iter
=
cache
->
find
(
arg
);
if
(
iter
!=
cache
->
end
())
{
return
iter
->
second
;
}
return
cache
->
emplace
(
arg
,
f
(
arg
)).
first
->
second
;
};
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
oneflow/core/common/cblas.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CBLAS_H_
#define ONEFLOW_CORE_COMMON_CBLAS_H_
#include <stddef.h>
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t
/* this may vary between platforms */
/* Storage-order selector, passed as the `order` / `Order` argument of the
 * level 2 and level 3 prototypes in this header. */
enum CBLAS_ORDER {
  CblasRowMajor = 101,
  CblasColMajor = 102
};
/* Transposition selector, passed as the `TransA` / `TransB` / `Trans`
 * argument of the prototypes in this header. */
enum CBLAS_TRANSPOSE {
  CblasNoTrans = 111,
  CblasTrans = 112,
  CblasConjTrans = 113
};
/* Upper/lower selector, passed as the `Uplo` argument of the prototypes in
 * this header. */
enum CBLAS_UPLO {
  CblasUpper = 121,
  CblasLower = 122
};
/* Unit/non-unit diagonal selector, passed as the `Diag` argument of the
 * prototypes in this header. */
enum CBLAS_DIAG {
  CblasNonUnit = 131,
  CblasUnit = 132
};
/* Left/right selector, passed as the `Side` argument of the level 3
 * prototypes in this header. */
enum CBLAS_SIDE {
  CblasLeft = 141,
  CblasRight = 142
};
#ifdef __cplusplus
extern
"C"
{
#endif
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
float
cblas_sdsdot
(
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
);
double
cblas_dsdot
(
const
int
N
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
);
float
cblas_sdot
(
const
int
N
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
);
double
cblas_ddot
(
const
int
N
,
const
double
*
X
,
const
int
incX
,
const
double
*
Y
,
const
int
incY
);
/*
* Functions having prefixes Z and C only
*/
void
cblas_cdotu_sub
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
dotu
);
void
cblas_cdotc_sub
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
dotc
);
void
cblas_zdotu_sub
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
dotu
);
void
cblas_zdotc_sub
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
dotc
);
/*
* Functions having prefixes S D SC DZ
*/
float
cblas_snrm2
(
const
int
N
,
const
float
*
X
,
const
int
incX
);
float
cblas_sasum
(
const
int
N
,
const
float
*
X
,
const
int
incX
);
double
cblas_dnrm2
(
const
int
N
,
const
double
*
X
,
const
int
incX
);
double
cblas_dasum
(
const
int
N
,
const
double
*
X
,
const
int
incX
);
float
cblas_scnrm2
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
float
cblas_scasum
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
double
cblas_dznrm2
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
double
cblas_dzasum
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX
cblas_isamax
(
const
int
N
,
const
float
*
X
,
const
int
incX
);
CBLAS_INDEX
cblas_idamax
(
const
int
N
,
const
double
*
X
,
const
int
incX
);
CBLAS_INDEX
cblas_icamax
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
CBLAS_INDEX
cblas_izamax
(
const
int
N
,
const
void
*
X
,
const
int
incX
);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void
cblas_sswap
(
const
int
N
,
float
*
X
,
const
int
incX
,
float
*
Y
,
const
int
incY
);
void
cblas_scopy
(
const
int
N
,
const
float
*
X
,
const
int
incX
,
float
*
Y
,
const
int
incY
);
void
cblas_saxpy
(
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
float
*
Y
,
const
int
incY
);
void
cblas_dswap
(
const
int
N
,
double
*
X
,
const
int
incX
,
double
*
Y
,
const
int
incY
);
void
cblas_dcopy
(
const
int
N
,
const
double
*
X
,
const
int
incX
,
double
*
Y
,
const
int
incY
);
void
cblas_daxpy
(
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
double
*
Y
,
const
int
incY
);
void
cblas_cswap
(
const
int
N
,
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
void
cblas_ccopy
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
void
cblas_caxpy
(
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
void
cblas_zswap
(
const
int
N
,
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
void
cblas_zcopy
(
const
int
N
,
const
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
void
cblas_zaxpy
(
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
Y
,
const
int
incY
);
/*
* Routines with S and D prefix only
*/
void
cblas_srotg
(
float
*
a
,
float
*
b
,
float
*
c
,
float
*
s
);
void
cblas_srotmg
(
float
*
d1
,
float
*
d2
,
float
*
b1
,
const
float
b2
,
float
*
P
);
void
cblas_srot
(
const
int
N
,
float
*
X
,
const
int
incX
,
float
*
Y
,
const
int
incY
,
const
float
c
,
const
float
s
);
void
cblas_srotm
(
const
int
N
,
float
*
X
,
const
int
incX
,
float
*
Y
,
const
int
incY
,
const
float
*
P
);
void
cblas_drotg
(
double
*
a
,
double
*
b
,
double
*
c
,
double
*
s
);
void
cblas_drotmg
(
double
*
d1
,
double
*
d2
,
double
*
b1
,
const
double
b2
,
double
*
P
);
void
cblas_drot
(
const
int
N
,
double
*
X
,
const
int
incX
,
double
*
Y
,
const
int
incY
,
const
double
c
,
const
double
s
);
void
cblas_drotm
(
const
int
N
,
double
*
X
,
const
int
incX
,
double
*
Y
,
const
int
incY
,
const
double
*
P
);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void
cblas_sscal
(
const
int
N
,
const
float
alpha
,
float
*
X
,
const
int
incX
);
void
cblas_dscal
(
const
int
N
,
const
double
alpha
,
double
*
X
,
const
int
incX
);
void
cblas_cscal
(
const
int
N
,
const
void
*
alpha
,
void
*
X
,
const
int
incX
);
void
cblas_zscal
(
const
int
N
,
const
void
*
alpha
,
void
*
X
,
const
int
incX
);
void
cblas_csscal
(
const
int
N
,
const
float
alpha
,
void
*
X
,
const
int
incX
);
void
cblas_zdscal
(
const
int
N
,
const
double
alpha
,
void
*
X
,
const
int
incX
);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void
cblas_sgemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
X
,
const
int
incX
,
const
float
beta
,
float
*
Y
,
const
int
incY
);
void
cblas_sgbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
int
KL
,
const
int
KU
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
X
,
const
int
incX
,
const
float
beta
,
float
*
Y
,
const
int
incY
);
void
cblas_strmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
float
*
A
,
const
int
lda
,
float
*
X
,
const
int
incX
);
void
cblas_stbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
float
*
A
,
const
int
lda
,
float
*
X
,
const
int
incX
);
void
cblas_stpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
float
*
Ap
,
float
*
X
,
const
int
incX
);
void
cblas_strsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
float
*
A
,
const
int
lda
,
float
*
X
,
const
int
incX
);
void
cblas_stbsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
float
*
A
,
const
int
lda
,
float
*
X
,
const
int
incX
);
void
cblas_stpsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
float
*
Ap
,
float
*
X
,
const
int
incX
);
void
cblas_dgemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
X
,
const
int
incX
,
const
double
beta
,
double
*
Y
,
const
int
incY
);
void
cblas_dgbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
int
KL
,
const
int
KU
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
X
,
const
int
incX
,
const
double
beta
,
double
*
Y
,
const
int
incY
);
void
cblas_dtrmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
double
*
A
,
const
int
lda
,
double
*
X
,
const
int
incX
);
void
cblas_dtbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
double
*
A
,
const
int
lda
,
double
*
X
,
const
int
incX
);
void
cblas_dtpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
double
*
Ap
,
double
*
X
,
const
int
incX
);
void
cblas_dtrsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
double
*
A
,
const
int
lda
,
double
*
X
,
const
int
incX
);
void
cblas_dtbsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
double
*
A
,
const
int
lda
,
double
*
X
,
const
int
incX
);
void
cblas_dtpsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
double
*
Ap
,
double
*
X
,
const
int
incX
);
void
cblas_cgemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_cgbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
int
KL
,
const
int
KU
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_ctrmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ctbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ctpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
Ap
,
void
*
X
,
const
int
incX
);
void
cblas_ctrsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ctbsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ctpsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
Ap
,
void
*
X
,
const
int
incX
);
void
cblas_zgemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_zgbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
int
M
,
const
int
N
,
const
int
KL
,
const
int
KU
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_ztrmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ztbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ztpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
Ap
,
void
*
X
,
const
int
incX
);
void
cblas_ztrsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ztbsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
int
K
,
const
void
*
A
,
const
int
lda
,
void
*
X
,
const
int
incX
);
void
cblas_ztpsv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
N
,
const
void
*
Ap
,
void
*
X
,
const
int
incX
);
/*
* Routines with S and D prefixes only
*/
void
cblas_ssymv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
X
,
const
int
incX
,
const
float
beta
,
float
*
Y
,
const
int
incY
);
void
cblas_ssbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
X
,
const
int
incX
,
const
float
beta
,
float
*
Y
,
const
int
incY
);
void
cblas_sspmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
Ap
,
const
float
*
X
,
const
int
incX
,
const
float
beta
,
float
*
Y
,
const
int
incY
);
void
cblas_sger
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
,
float
*
A
,
const
int
lda
);
void
cblas_ssyr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
float
*
A
,
const
int
lda
);
void
cblas_sspr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
float
*
Ap
);
void
cblas_ssyr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
,
float
*
A
,
const
int
lda
);
void
cblas_sspr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
float
*
X
,
const
int
incX
,
const
float
*
Y
,
const
int
incY
,
float
*
A
);
void
cblas_dsymv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
X
,
const
int
incX
,
const
double
beta
,
double
*
Y
,
const
int
incY
);
void
cblas_dsbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
X
,
const
int
incX
,
const
double
beta
,
double
*
Y
,
const
int
incY
);
void
cblas_dspmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
Ap
,
const
double
*
X
,
const
int
incX
,
const
double
beta
,
double
*
Y
,
const
int
incY
);
void
cblas_dger
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
const
double
*
Y
,
const
int
incY
,
double
*
A
,
const
int
lda
);
void
cblas_dsyr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
double
*
A
,
const
int
lda
);
void
cblas_dspr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
double
*
Ap
);
void
cblas_dsyr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
const
double
*
Y
,
const
int
incY
,
double
*
A
,
const
int
lda
);
void
cblas_dspr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
double
*
X
,
const
int
incX
,
const
double
*
Y
,
const
int
incY
,
double
*
A
);
/*
* Routines with C and Z prefixes only
*/
void
cblas_chemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_chbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_chpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
Ap
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_cgeru
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_cgerc
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_cher
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
A
,
const
int
lda
);
void
cblas_chpr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
float
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
A
);
void
cblas_cher2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_chpr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
Ap
);
void
cblas_zhemv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_zhbmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_zhpmv
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
Ap
,
const
void
*
X
,
const
int
incX
,
const
void
*
beta
,
void
*
Y
,
const
int
incY
);
void
cblas_zgeru
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_zgerc
(
const
enum
CBLAS_ORDER
order
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_zher
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
A
,
const
int
lda
);
void
cblas_zhpr
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
double
alpha
,
const
void
*
X
,
const
int
incX
,
void
*
A
);
void
cblas_zher2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
A
,
const
int
lda
);
void
cblas_zhpr2
(
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
N
,
const
void
*
alpha
,
const
void
*
X
,
const
int
incX
,
const
void
*
Y
,
const
int
incY
,
void
*
Ap
);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void
cblas_sgemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
);
void
cblas_ssymm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
);
void
cblas_ssyrk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
beta
,
float
*
C
,
const
int
ldc
);
void
cblas_ssyr2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
);
void
cblas_strmm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
float
*
B
,
const
int
ldb
);
void
cblas_strsm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
float
*
B
,
const
int
ldb
);
void
cblas_dgemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
);
void
cblas_dsymm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
);
void
cblas_dsyrk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
beta
,
double
*
C
,
const
int
ldc
);
void
cblas_dsyr2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
);
void
cblas_dtrmm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
double
*
B
,
const
int
ldb
);
void
cblas_dtrsm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
double
*
B
,
const
int
ldb
);
void
cblas_cgemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_csymm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_csyrk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_csyr2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_ctrmm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
void
*
B
,
const
int
ldb
);
void
cblas_ctrsm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
void
*
B
,
const
int
ldb
);
void
cblas_zgemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_TRANSPOSE
TransB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zsymm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zsyrk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zsyr2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_ztrmm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
void
*
B
,
const
int
ldb
);
void
cblas_ztrsm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
TransA
,
const
enum
CBLAS_DIAG
Diag
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
void
*
B
,
const
int
ldb
);
/*
* Routines with prefixes C and Z only
*/
void
cblas_chemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_cherk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
void
*
A
,
const
int
lda
,
const
float
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_cher2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
float
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zhemm
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_SIDE
Side
,
const
enum
CBLAS_UPLO
Uplo
,
const
int
M
,
const
int
N
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
void
*
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zherk
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
void
*
A
,
const
int
lda
,
const
double
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_zher2k
(
const
enum
CBLAS_ORDER
Order
,
const
enum
CBLAS_UPLO
Uplo
,
const
enum
CBLAS_TRANSPOSE
Trans
,
const
int
N
,
const
int
K
,
const
void
*
alpha
,
const
void
*
A
,
const
int
lda
,
const
void
*
B
,
const
int
ldb
,
const
double
beta
,
void
*
C
,
const
int
ldc
);
void
cblas_xerbla
(
int
p
,
const
char
*
rout
,
const
char
*
form
,
...);
#ifdef __cplusplus
}
#endif
#endif
oneflow/core/common/channel.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CHANNEL_H_
#define ONEFLOW_CORE_COMMON_CHANNEL_H_
#include "oneflow/core/common/util.h"
namespace
oneflow
{
// Result code returned by Channel<T>::Send, Receive and ReceiveMany.
//  - kChannelStatusSuccess (0): the item(s) were enqueued / dequeued.
//  - kChannelStatusErrorClosed: Send() was called after Close(), or a receive
//    call woke up and found the queue empty (which only happens once the
//    channel has been closed).
enum
ChannelStatus
{
kChannelStatusSuccess
=
0
,
kChannelStatusErrorClosed
};
// A mutex/condition-variable protected FIFO of T.
// Send() pushes items, Receive()/ReceiveMany() block until an item is available
// or the channel is closed, and Close() marks the channel closed and wakes all
// waiters. Items still queued at Close() time can still be received; only a
// receive on an empty queue reports kChannelStatusErrorClosed.
template
<
typename
T
>
class
Channel
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
Channel
);
// A new channel starts open and empty.
Channel
()
:
is_closed_
(
false
)
{}
~
Channel
()
=
default
;
// Enqueues `item` (forwarded into the queue); fails with
// kChannelStatusErrorClosed once Close() has been called.
template
<
typename
U
>
ChannelStatus
Send
(
U
&&
item
);
// Blocks until one item can be moved into `*item` or the channel is closed and drained.
ChannelStatus
Receive
(
T
*
item
);
// Blocks like Receive(), then moves every queued item into `*items`.
ChannelStatus
ReceiveMany
(
std
::
queue
<
T
>*
items
);
// Marks the channel closed and wakes every blocked receiver.
void
Close
();
private:
// Pending items, in send order.
std
::
queue
<
T
>
queue_
;
// Guards queue_ and is_closed_.
std
::
mutex
mutex_
;
// Set to true by Close(); never reset.
bool
is_closed_
;
// Signalled by Send() and Close(); waited on by Receive()/ReceiveMany().
std
::
condition_variable
cond_
;
};
// Enqueues `item` (perfectly forwarded into the queue) and wakes one blocked receiver.
//
// Returns kChannelStatusErrorClosed, without enqueuing anything, if Close() has
// already been called; returns kChannelStatusSuccess otherwise.
template<typename T>
template<typename U>
ChannelStatus Channel<T>::Send(U&& item) {
  {
    std::unique_lock<std::mutex> lock(mutex_);
    if (is_closed_) { return kChannelStatusErrorClosed; }
    queue_.push(std::forward<U>(item));
  }
  // Notify on every push, not only on the empty -> non-empty transition.
  // With several receivers blocked in Receive(), a second Send() that lands
  // before the first woken receiver has popped its item would otherwise skip
  // the notification, leaving an item queued while another receiver keeps
  // sleeping until the next Send()-on-empty or Close().
  // The notification is issued after the lock is released so the woken
  // receiver does not immediately block on mutex_.
  cond_.notify_one();
  return kChannelStatusSuccess;
}
// Blocks until an item is available or the channel has been closed.
//
// On success the oldest queued item is moved into `*item` and
// kChannelStatusSuccess is returned. If the wait ends with an empty queue
// (i.e. the channel was closed and fully drained) `*item` is left untouched and
// kChannelStatusErrorClosed is returned.
template<typename T>
ChannelStatus Channel<T>::Receive(T* item) {
  std::unique_lock<std::mutex> guard(mutex_);
  // Equivalent to waiting with the predicate "queue non-empty or closed".
  while (queue_.empty() && !is_closed_) { cond_.wait(guard); }
  if (queue_.empty()) { return kChannelStatusErrorClosed; }
  *item = std::move(queue_.front());
  queue_.pop();
  return kChannelStatusSuccess;
}
// Blocks until at least one item is available or the channel has been closed,
// then moves *every* queued item (in FIFO order) to the back of `*items`.
//
// Returns kChannelStatusErrorClosed, leaving `*items` untouched, when the wait
// ends with an empty queue; returns kChannelStatusSuccess otherwise.
template<typename T>
ChannelStatus Channel<T>::ReceiveMany(std::queue<T>* items) {
  std::unique_lock<std::mutex> guard(mutex_);
  // Equivalent to waiting with the predicate "queue non-empty or closed".
  while (queue_.empty() && !is_closed_) { cond_.wait(guard); }
  if (queue_.empty()) { return kChannelStatusErrorClosed; }
  // The queue is known to be non-empty here, so a do/while drains it fully.
  do {
    items->push(std::move(queue_.front()));
    queue_.pop();
  } while (!queue_.empty());
  return kChannelStatusSuccess;
}
// Marks the channel as closed and wakes every thread blocked in
// Receive()/ReceiveMany(). Subsequent Send() calls fail; items already queued
// remain receivable.
template<typename T>
void Channel<T>::Close() {
  std::lock_guard<std::mutex> guard(mutex_);
  is_closed_ = true;
  cond_.notify_all();
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHANNEL_H_
oneflow/core/common/channel_test.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/range.h"
namespace
oneflow
{
// Sender thread body: sends every integer in [range.begin(), range.end()) to
// `channel`, stopping early as soon as a Send() does not succeed.
void CallFromSenderThread(Channel<int>* channel, Range range) {
  int value = range.begin();
  while (value < range.end()) {
    if (channel->Send(value) != kChannelStatusSuccess) { return; }
    ++value;
  }
}
// Receiver thread body: keeps receiving integers from `channel` until Receive()
// stops succeeding, counting each received value in `(*visit)[value]`.
void CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {
  int received = -1;
  for (;;) {
    if (channel->Receive(&received) != kChannelStatusSuccess) { break; }
    ++visit->at(received);
  }
}
// 30 sender threads each push the values [0, 200) into one channel while 40
// receiver threads drain it. After the senders finish and the channel is
// closed, every value must have been received exactly `sender_num` times in
// total across all receivers.
TEST(Channel, 30sender40receiver) {
  Channel<int> channel;
  std::vector<std::thread> senders;
  std::vector<std::thread> receivers;
  int sender_num = 30;
  int receiver_num = 40;
  int range_num = 200;

  // visits[r][v] counts how often receiver r saw value v.
  std::vector<std::vector<int>> visits(receiver_num, std::vector<int>(range_num, 0));

  for (int i = 0; i < sender_num; ++i) {
    senders.emplace_back(CallFromSenderThread, &channel, Range(0, range_num));
  }
  for (int i = 0; i < receiver_num; ++i) {
    receivers.emplace_back(CallFromReceiverThread, &visits[i], &channel);
  }

  // Close only after all senders are done so that no Send() is rejected.
  for (std::thread& sender : senders) { sender.join(); }
  channel.Close();
  for (std::thread& receiver : receivers) { receiver.join(); }

  for (int value = 0; value < range_num; ++value) {
    int visit_count = 0;
    for (const std::vector<int>& visit : visits) { visit_count += visit[value]; }
    ASSERT_EQ(visit_count, sender_num);
  }
}
}
// namespace oneflow
oneflow/core/common/check_level.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdlib>
#include <type_traits>
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/env_var/debug_mode.h"
namespace
oneflow
{
// Returns true when checks of level `check_level` should run: either the
// process is in debug mode, or the check level configured through the
// environment is at least `check_level`.
//
// Both the environment value and the debug-mode flag are read once, on first
// call, and cached for the lifetime of the process.
bool IsEnvEnabled(int32_t check_level) {
  // The variable was historically spelled "ONEFOW_CHECK_LEVEL" (missing 'L').
  // Read the correctly spelled name first and fall back to the legacy spelling
  // so existing setups keep working; -1 (nothing enabled) if neither is set.
  static const int env_check_level = ParseIntegerFromEnv(
      "ONEFLOW_CHECK_LEVEL", ParseIntegerFromEnv("ONEFOW_CHECK_LEVEL", -1));
  static const bool env_debug_mode = IsInDebugMode();
  return env_debug_mode || env_check_level >= check_level;
}
}
// namespace oneflow
oneflow/core/common/check_level.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
#define ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
namespace
oneflow
{
// Returns whether runtime checks of the given `check_level` are enabled for
// this process (defined in check_level.cpp).
bool
IsEnvEnabled
(
int32_t
check_level
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
oneflow/core/common/constant.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CONSTANT_H_
#define ONEFLOW_CORE_COMMON_CONSTANT_H_
#include <string>
namespace
oneflow
{
// Sentinel session id (-1).
// NOTE(review): presumably means "no valid session" — confirm against users of this constant.
static
const
int64_t
kInvalidSessionId
=
-
1
;
// Empty-string tag value.
static
const
std
::
string
kNoPassTag
=
""
;
// The literal name "main_op".
static
const
std
::
string
kMainOp
=
"main_op"
;
// NOTE(review): the name suggests an upper bound on split axes; whether 6 is
// inclusive or exclusive is not visible here — confirm at the use sites.
static
const
int64_t
kMaxSplitAxis
=
6
;
// Diagnostic text hinting that different ranks may be executing different code.
static
const
std
::
string
kAsymmetricCodeErrorMsg
=
"Maybe executing different code in different ranks, please check if the code is branched and "
"operates on the global tensor."
;
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CONSTANT_H_
oneflow/core/common/container_util.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
#define ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
#include <vector>
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"
namespace
oneflow
{
// Looks up `key` in `map` and returns the mapped value, or `default_val` when
// the key is absent.
// NOTE(review): if scalar_or_const_ref_t yields a const reference for this
// mapped_type (the name suggests it can — confirm), `default_val` must outlive
// the returned value; passing a temporary would then dangle.
template<typename MapT, typename KeyT, typename U>
scalar_or_const_ref_t<typename MapT::mapped_type> MapAt(const MapT& map, const KeyT& key,
                                                        const U& default_val) {
  const auto found = map.find(key);
  if (found != map.end()) { return found->second; }
  return default_val;
}
// Read-only lookup: returns the value mapped to `key` in `map`.
// When the key is absent, CHECK_OR_RETURN takes the error path instead of
// returning a value (the macro is defined elsewhere).
template<typename MapT, typename KeyT>
Maybe<scalar_or_const_ref_t<typename MapT::mapped_type>> MapAt(const MapT& map, const KeyT& key) {
  const auto iter = map.find(key);
  CHECK_OR_RETURN(iter != map.end());
  return iter->second;
}
// Mutable lookup: returns a reference to the value mapped to `key` in `map`,
// so the caller can modify it in place.
// When the key is absent, CHECK_OR_RETURN takes the error path instead of
// returning a reference (the macro is defined elsewhere).
template<typename MapT, typename KeyT>
Maybe<typename MapT::mapped_type&> MapAt(MapT& map, const KeyT& key) {
  const auto iter = map.find(key);
  CHECK_OR_RETURN(iter != map.end());
  return iter->second;
}
// Read-only element access: returns vec[index].
// CHECK_LT_OR_RETURN guards the access with `index < vec.size()`; the macro is
// defined elsewhere and presumably returns an error Maybe when the check fails.
template
<
typename
VecT
>
Maybe
<
scalar_or_const_ref_t
<
typename
VecT
::
value_type
>>
VectorAt
(
const
VecT
&
vec
,
typename
VecT
::
size_type
index
)
{
CHECK_LT_OR_RETURN
(
index
,
vec
.
size
());
return
vec
[
index
];
}
// Mutable element access: returns a reference to vec[index] so the caller can
// assign through it. Rejected at compile time for element type bool, because
// vector<bool> elements cannot be bound to a `bool&`.
// CHECK_LT_OR_RETURN guards the access with `index < vec.size()`; the macro is
// defined elsewhere and presumably returns an error Maybe when the check fails.
template
<
typename
VecT
>
Maybe
<
typename
VecT
::
value_type
&>
VectorAt
(
VecT
&
vec
,
typename
VecT
::
size_type
index
)
{
static_assert
(
!
std
::
is_same
<
typename
VecT
::
value_type
,
bool
>::
value
,
"VectorAt(vector<bool>&, size_t) is not supported."
);
CHECK_LT_OR_RETURN
(
index
,
vec
.
size
());
return
vec
[
index
];
}
// Specialization of the read-only VectorAt for std::vector<bool>: returns the
// element by value as a plain bool.
// CHECK_LT_OR_RETURN guards the access with `index < vec.size()`; the macro is
// defined elsewhere and presumably returns an error Maybe when the check fails.
template
<
>
inline
Maybe
<
bool
>
VectorAt
(
const
std
::
vector
<
bool
>&
vec
,
typename
std
::
vector
<
bool
>::
size_type
index
)
{
CHECK_LT_OR_RETURN
(
index
,
vec
.
size
());
// vec[index] may be a proxy object rather than a bool; convert it explicitly
return
static_cast
<
bool
>
(
vec
[
index
]);
}
// Concatenates the elements of `con`, streamed with operator<<, separated by
// `delimiter`. Returns an empty string for an empty container; no delimiter is
// written before the first or after the last element.
//
// Works with any container usable in a range-based for loop (forward iteration
// is sufficient).
template<typename T>
std::string Join(const T& con, const std::string& delimiter) {
  std::ostringstream os;
  bool is_first = true;
  for (const auto& elem : con) {
    // The previous implementation passed `delimiter` (a std::string) to
    // std::ostream_iterator, whose constructor only accepts a `const char*`,
    // so the template could not be instantiated. Writing the delimiter
    // directly also keeps delimiters containing '\0' intact.
    if (!is_first) { os << delimiter; }
    os << elem;
    is_first = false;
  }
  return os.str();
}
// A "set" for a handful of elements, stored as a plain vector and searched
// linearly.
template<typename T>
using SmallSet = std::vector<T>;

// Inserts `elem` into `*vec` unless an equal element (operator==) is already
// present.
//
// Returns {iterator to the existing element, false} when found, otherwise
// appends `elem` and returns {iterator to the new last element, true}.
template<typename T>
std::pair<typename SmallSet<T>::iterator, bool> SmallSetInsert(SmallSet<T>* vec, const T& elem) {
  auto pos = vec->begin();
  while (pos != vec->end()) {
    if (*pos == elem) { return std::make_pair(pos, false); }
    ++pos;
  }
  vec->push_back(elem);
  // push_back may reallocate, so the iterator is taken from the vector afterwards.
  return std::make_pair(vec->end() - 1, true);
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
oneflow/core/common/container_util_test.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/container_util.h"
namespace
oneflow
{
namespace
test
{
// The mutable VectorAt overload must return a reference that can be assigned
// through, and writes must not disturb neighbouring elements.
TEST(VectorAt, write_int_vector) {
  std::vector<int> vec{1, 2, 3, 4, 5};

  // Initial reads.
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 2);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 4);

  // Write through the returned reference, then read back.
  CHECK_JUST(VectorAt(vec, 1)) = 6;
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 6);

  CHECK_JUST(VectorAt(vec, 3)) = 8;
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 8);

  // Elements that were not written keep their original values.
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)), 1);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 2)), 3);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 4)), 5);
}
namespace {

// Minimal user-defined value type wrapping a single public int; used to
// exercise VectorAt with a non-scalar element type.
class A {
 public:
  explicit A(int value) : a(value) {}
  int a;
};

}
// namespace
// Same as the int test, but with a class element type: the reference returned
// by the mutable VectorAt must accept whole-object assignment.
TEST(VectorAt, write_custom_class_vector) {
  std::vector<A> vec{A(1), A(2)};

  // Initial reads.
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 1);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 2);

  // Replace each element through the returned reference and read it back.
  CHECK_JUST(VectorAt(vec, 0)) = A(3);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 3);

  CHECK_JUST(VectorAt(vec, 1)) = A(4);
  EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 4);
}
}
// namespace test
}
// namespace oneflow
oneflow/core/common/cplusplus_17.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#define ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#if __cplusplus < 201703L
#include <functional>
#include <numeric>
namespace
std
{
// a sequential version of inclusive_scan and exclusive_scan
// Sequential inclusive prefix sum using operator+: writes
// x0, x0+x1, x0+x1+x2, ... to d_first and returns the iterator past the last
// element written. Simply forwards to partial_sum.
template<class InputIt, class OutputIt>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first) {
  return std::partial_sum(first, last, d_first);
}
// Sequential inclusive scan with a caller-supplied binary operation: writes
// x0, op(x0, x1), op(op(x0, x1), x2), ... to d_first and returns the iterator
// past the last element written. Simply forwards to partial_sum.
template<class InputIt, class OutputIt, class BinaryOperation>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first,
                        BinaryOperation binary_op) {
  return std::partial_sum(first, last, d_first, binary_op);
}
// Sequential inclusive scan seeded with `init`: writes
//   op(init, x0), op(op(init, x0), x1), ...
// to d_first and returns the iterator past the last element written (or
// d_first unchanged for an empty input range).
//
// `init` is always the left operand of the first application, so
// non-commutative operations behave as in C++17 std::inclusive_scan.
// The running value is held in the input iterator's value_type, matching the
// exclusive_scan back-port in this header.
template<class InputIt, class OutputIt, class BinaryOperation, class T>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first, BinaryOperation binary_op,
                        T init) {
  // Based on https://en.cppreference.com/w/cpp/algorithm/partial_sum
  if (first == last) return d_first;

  // Previously called an undeclared `op` with the operands swapped, which made
  // this overload impossible to instantiate.
  typename std::iterator_traits<InputIt>::value_type sum = binary_op(init, *first);
  *d_first = sum;

  while (++first != last) {
    sum = binary_op(sum, *first);
    *++d_first = sum;
  }
  return ++d_first;
}
// Sequential exclusive scan: writes
//   init, op(init, x0), op(op(init, x0), x1), ...
// (one output per input element; the last input element is never read) to
// d_first and returns the iterator past the last element written, or d_first
// unchanged for an empty input range.
//
// The running value is held in the input iterator's value_type (not T), so an
// `int` literal seed does not narrow a wider element type.
// Requires only forward (multi-pass) iterators.
template<class InputIt, class OutputIt, class T, class BinaryOperation>
OutputIt exclusive_scan(InputIt first, InputIt last, OutputIt d_first, T init,
                        BinaryOperation binary_op) {
  if (first == last) return d_first;

  typename std::iterator_traits<InputIt>::value_type sum = init;
  *d_first = sum;

  // Walk with a one-element look-ahead so the loop stops before the last
  // element. The previous version achieved this by decrementing both `first`
  // and `last`, which moved `first` before the beginning of the range
  // (undefined behavior) and needlessly required bidirectional iterators.
  InputIt next = first;
  for (++next; next != last; first = next, ++next) {
    sum = binary_op(sum, *first);
    *++d_first = sum;
  }
  return ++d_first;
}
// Exclusive prefix sum using addition: forwards to the five-argument
// exclusive_scan with std::plus<> as the binary operation.
template<class InputIt, class OutputIt, class T>
OutputIt exclusive_scan(InputIt first, InputIt last, OutputIt d_first, T init) {
  std::plus<> add;
  return exclusive_scan(first, last, d_first, init, add);
}
}
// namespace std
#endif
#endif // ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
oneflow/core/common/cplusplus_17_test.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <gtest/gtest.h>
#include <functional>
#include <iostream>
#include <iterator>
#include <vector>
#include "oneflow/core/common/cplusplus_17.h"
namespace
oneflow
{
namespace
test
{
// Exercises the exclusive_scan / inclusive_scan helpers on a fixed input with
// addition, multiplication and reverse iteration, comparing against
// hand-computed prefix results.
TEST(Scan, scan) {
  std::vector<int> data{3, 1, 4, 1, 5, 9, 2, 6};
  std::vector<int> output;
  std::vector<int> ref_output;

  // Exclusive prefix sum seeded with 0.
  std::exclusive_scan(data.begin(), data.end(), std::back_inserter(output), 0);
  ref_output = {0, 3, 4, 8, 9, 14, 23, 25};
  EXPECT_EQ(output, ref_output);
  output.clear();

  // Inclusive prefix sum.
  std::inclusive_scan(data.begin(), data.end(), std::back_inserter(output));
  ref_output = {3, 4, 8, 9, 14, 23, 25, 31};
  EXPECT_EQ(output, ref_output);
  output.clear();

  // Exclusive prefix product seeded with 1.
  std::exclusive_scan(data.begin(), data.end(), std::back_inserter(output), 1,
                      std::multiplies<>{});
  ref_output = {1, 3, 3, 12, 12, 60, 540, 1080};
  EXPECT_EQ(output, ref_output);
  output.clear();

  // Inclusive prefix product.
  std::inclusive_scan(data.begin(), data.end(), std::back_inserter(output), std::multiplies<>{});
  ref_output = {3, 3, 12, 12, 60, 540, 1080, 6480};
  EXPECT_EQ(output, ref_output);
  output.clear();

  // Exclusive prefix product over the reversed input.
  std::exclusive_scan(data.rbegin(), data.rend(), std::back_inserter(output), 1,
                      std::multiplies<>{});
  ref_output = {1, 6, 12, 108, 540, 540, 2160, 2160};
  EXPECT_EQ(output, ref_output);
  output.clear();
}
}
// namespace test
}
// namespace oneflow
oneflow/core/common/cpp_attribute.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
#define ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
#include <glog/logging.h>
#define likely GOOGLE_PREDICT_TRUE
#define unlikely GOOGLE_PREDICT_FALSE
#endif // ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
oneflow/core/common/data_type.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/tensor_buffer.h"
namespace
oneflow
{
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in BOOL_DATA_TYPE_SEQ, false otherwise.
bool IsBoolDataType(DataType data_type) {
#define BOOL_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(BOOL_CASE, BOOL_DATA_TYPE_SEQ)
    default: return false;
  }
#undef BOOL_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in INT_DATA_TYPE_SEQ or UNSIGNED_INT_DATA_TYPE_SEQ, false otherwise.
bool IsIntegralDataType(DataType data_type) {
#define INTEGRAL_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(INTEGRAL_CASE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)
    default: return false;
  }
#undef INTEGRAL_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in FLOATING_DATA_TYPE_SEQ, false otherwise.
bool IsFloatingDataType(DataType data_type) {
#define FLOATING_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(FLOATING_CASE, FLOATING_DATA_TYPE_SEQ)
    default: return false;
  }
#undef FLOATING_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in POD_DATA_TYPE_SEQ, false otherwise.
bool IsPODDataType(DataType data_type) {
#define POD_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(POD_CASE, POD_DATA_TYPE_SEQ)
    default: return false;
  }
#undef POD_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in POD_AND_HALF_DATA_TYPE_SEQ, false otherwise.
bool IsPODAndHalfDataType(DataType data_type) {
#define POD_AND_HALF_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(POD_AND_HALF_CASE, POD_AND_HALF_DATA_TYPE_SEQ)
    default: return false;
  }
#undef POD_AND_HALF_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in INDEX_DATA_TYPE_SEQ, false otherwise.
bool IsIndexDataType(DataType data_type) {
#define INDEX_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(INDEX_CASE, INDEX_DATA_TYPE_SEQ)
    default: return false;
  }
#undef INDEX_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in FLOATING_DATA_TYPE_SEQ or FLOAT16_DATA_TYPE_SEQ, false otherwise.
bool IsSupportRequireGradDataType(DataType data_type) {
#define REQUIRE_GRAD_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(REQUIRE_GRAD_CASE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)
    default: return false;
  }
#undef REQUIRE_GRAD_CASE
}
// Returns true when `data_type` matches one of the `type_proto` entries listed
// in NO_BOXING_DATA_TYPE_SEQ, false otherwise.
bool NotSupportBoxingDataType(DataType data_type) {
#define NO_BOXING_CASE(type_cpp, type_proto) \
  case type_proto: return true;
  switch (data_type) {
    OF_PP_FOR_EACH_TUPLE(NO_BOXING_CASE, NO_BOXING_DATA_TYPE_SEQ)
    default: return false;
  }
#undef NO_BOXING_CASE
}
// Returns the size in bytes of one element of `data_type`.
// Fixed-width types map to 1/2/4/8/16 bytes; kOFRecord and kTensorBuffer map
// to the sizeof their C++ object. Any other value is reported via LOG(FATAL).
size_t GetSizeOfDataType(DataType data_type) {
  switch (data_type) {
    // 8-bit
    case kChar:
    case kInt8:
    case kUInt8:
    case kBool: return 1;
    // 16-bit
    case kInt16:
    case kUInt16:
    case kFloat16:
    case kBFloat16: return 2;
    // 32-bit
    case kInt32:
    case kUInt32:
    case kFloat:
    case kComplex32: return 4;
    // 64-bit
    case kInt64:
    case kUInt64:
    case kDouble:
    case kComplex64: return 8;
    // 128-bit
    case kInt128:
    case kUInt128:
    case kComplex128: return 16;
    // non pod
    case kOFRecord: return sizeof(OFRecord);
    case kTensorBuffer: return sizeof(TensorBuffer);
    default: LOG(FATAL) << "invalid data_type: " << DataType_Name(data_type);
  }
}
namespace {

// Sanity checks on the type assumptions used by this file, registered to run
// through COMMAND below:
//  - compile time: the fixed-width integer types have the same size as
//    char/short/int/long long;
//  - run time (CUDA or ROCm builds): the host float16 and the device half
//    produce identical 16-bit patterns for the zero/one/max/min constants;
//  - run time: GetMaxVal<T>/GetMinVal<T> agree with std::numeric_limits<T>
//    for every type in MAX_VAL_SEQ / MIN_VAL_SEQ.
void CheckDataType() {
  static_assert(sizeof(int8_t) == sizeof(char), "sizeof(int8_t) != sizeof(char)");
  static_assert(sizeof(int16_t) == sizeof(short), "sizeof(int16_t) != sizeof(short)");
  static_assert(sizeof(int32_t) == sizeof(int), "sizeof(int32_t) != sizeof(int)");
  static_assert(sizeof(int64_t) == sizeof(long long), "sizeof(int64_t) != sizeof(long long)");

  // The CUDA and ROCm variants of this check are textually identical, so they
  // share a single guarded section.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#define CHECK_DEVICE_FP16(get_val)                              \
  do {                                                          \
    float16 host_fp16 = get_val<float16>();                     \
    half device_fp16 = get_val<half>();                         \
    CHECK_EQ(*(uint16_t*)&host_fp16, *(uint16_t*)&device_fp16); \
  } while (0)

  CHECK_DEVICE_FP16(GetZeroVal);
  CHECK_DEVICE_FP16(GetOneVal);
  CHECK_DEVICE_FP16(GetMaxVal);
  CHECK_DEVICE_FP16(GetMinVal);
#undef CHECK_DEVICE_FP16
#endif

#define CHECK_MAX_VAL(T, limit_value) CHECK_EQ(GetMaxVal<T>(), std::numeric_limits<T>::max());
  OF_PP_FOR_EACH_TUPLE(CHECK_MAX_VAL, MAX_VAL_SEQ);
#undef CHECK_MAX_VAL

#define CHECK_MIN_VAL(T, limit_value) CHECK_EQ(GetMinVal<T>(), std::numeric_limits<T>::lowest());
  OF_PP_FOR_EACH_TUPLE(CHECK_MIN_VAL, MIN_VAL_SEQ);
#undef CHECK_MIN_VAL
}

COMMAND(CheckDataType());

}
// namespace
}
// namespace oneflow
oneflow/core/common/data_type.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_H_
#include <cfloat>
#include <type_traits>
#if defined(WITH_CUDA)
#include <cuda_fp16.h>
#endif
#if defined(WITH_ROCM)
#include <hip/hip_fp16.h>
#endif
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/record/record.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/device_type.h"
#include <half.hpp>
namespace
oneflow
{
// Host-side half-precision type, taken from the half_float library (<half.hpp>).
using float16 = half_float::half;

// Type trait: IsFloat16<T>::value is true for the half-precision types known
// to this build (host float16, plus the device `half` when a GPU backend is
// enabled) and false for everything else.
template<typename T>
struct IsFloat16 : std::false_type {};

template<>
struct IsFloat16<float16> : std::true_type {};

// A single guarded specialization for the device type: having one copy per
// backend macro would be a redefinition if both macros were ever defined.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
template<>
struct IsFloat16<half> : std::true_type {};
#endif  // defined(WITH_CUDA) || defined(WITH_ROCM)
// Type Trait: IsFloating
// False by default; true for every C++ type listed in FLOATING_DATA_TYPE_SEQ.
template<typename T>
struct IsFloating : std::false_type {};

#define SPECIALIZE_TRUE_FLOATING(type_cpp, type_proto) \
  template<>                                           \
  struct IsFloating<type_cpp> : std::true_type {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_FLOATING, FLOATING_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_FLOATING
// Type Trait: IsIntegral
// False by default; true for every C++ type listed in INT_DATA_TYPE_SEQ.
template<typename T>
struct IsIntegral : std::false_type {};

#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \
  template<>                                           \
  struct IsIntegral<type_cpp> : std::true_type {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: IsUnsignedIntegral
// False by default; true for every C++ type listed in UNSIGNED_INT_DATA_TYPE_SEQ.
template<typename T>
struct IsUnsignedIntegral : std::false_type {};

#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \
  template<>                                           \
  struct IsUnsignedIntegral<type_cpp> : std::true_type {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, UNSIGNED_INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: GetDataType
// Maps a C++ type to its DataType enum value (GetDataType<T>::value). The
// second parameter exists only to allow enable_if-based partial specialization.
template<typename T, typename T2 = void>
struct GetDataType;

// `void` is reported as kChar.
template<>
struct GetDataType<void> : std::integral_constant<DataType, DataType::kChar> {};

// For every (C++ type, enum value) pair this generates
//  * the GetDataType specialization, and
//  * a GetTypeByDataType overload whose return type is the C++ type, used by
//    DataTypeToType below for the reverse lookup.
#define SPECIALIZE_GET_DATA_TYPE(type_cpp, type_proto)                            \
  template<>                                                                      \
  struct GetDataType<type_cpp> : std::integral_constant<DataType, type_proto> {}; \
  inline type_cpp GetTypeByDataType(std::integral_constant<DataType, type_proto>) { return {}; }
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE, ALL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ);
#undef SPECIALIZE_GET_DATA_TYPE

// Any remaining type flagged by IsFloat16 maps to kFloat16.
template<typename T>
struct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type>
    : std::integral_constant<DataType, DataType::kFloat16> {};

// Reverse mapping: DataType enum value -> C++ type.
template<DataType type>
using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{}));
// Function qualifier used by helpers that must be callable from device code:
// host+device and force-inlined under nvcc / hipcc, plain `inline` otherwise.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#else
#define OF_DEVICE_FUNC inline
#endif
// Zero of type T, for every T that is not a half-precision type
// (those have their own bit-pattern based overload).
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetZeroVal() {
  return static_cast<T>(0);
}
// One of type T, for every T that is not a half-precision type
// (those have their own bit-pattern based overload).
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetOneVal() {
  return static_cast<T>(1);
}
// Lowest value of T for non-half types. Declaration only: the definitions are
// explicit specializations generated from MIN_VAL_SEQ.
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMinVal();
// Largest value of T for non-half types. Declaration only: the definitions are
// explicit specializations generated from MAX_VAL_SEQ.
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMaxVal();
// On Apple platforms `unsigned long` gets its own entry in the limit tables
// (presumably because it is a type distinct from uint64_t there -- confirm).
#ifdef __APPLE__
#define APPLE_MAX_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, ULONG_MAX)
#else
#define APPLE_MAX_VAL_SEQ
#endif
// Table of (type, largest value) pairs; consumed below to generate the
// GetMaxVal<T>() specializations.
#define MAX_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, UINT8_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint16_t, UINT16_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint32_t, UINT32_MAX) \
APPLE_MAX_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint64_t, UINT64_MAX) \
OF_PP_MAKE_TUPLE_SEQ(float, FLT_MAX) \
OF_PP_MAKE_TUPLE_SEQ(double, DBL_MAX) \
OF_PP_MAKE_TUPLE_SEQ(bool, true)
// Same Apple-only extra entry for the table of lowest values.
#ifdef __APPLE__
#define APPLE_MIN_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, 0)
#else
#define APPLE_MIN_VAL_SEQ
#endif
// Table of (type, lowest value) pairs; for float/double this is the most
// negative finite value (-FLT_MAX / -DBL_MAX), not the smallest positive one.
#define MIN_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MIN) \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(uint16_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(uint32_t, 0) \
APPLE_MIN_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint64_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(float, -FLT_MAX) \
OF_PP_MAKE_TUPLE_SEQ(double, -DBL_MAX) \
OF_PP_MAKE_TUPLE_SEQ(bool, false)
// Generates one explicit specialization GetMaxVal<T>() per tuple of
// MAX_VAL_SEQ, each returning that tuple's limit value.
#define SPECIALIZE_MAX_VAL(T, limit_value) \
template<> \
OF_DEVICE_FUNC T GetMaxVal<T>() { \
return limit_value; \
}
OF_PP_FOR_EACH_TUPLE
(
SPECIALIZE_MAX_VAL
,
MAX_VAL_SEQ
);
#undef SPECIALIZE_MAX_VAL
// Generates one explicit specialization GetMinVal<T>() per tuple of
// MIN_VAL_SEQ, each returning that tuple's limit value.
#define SPECIALIZE_MIN_VAL(T, limit_value) \
template<> \
OF_DEVICE_FUNC T GetMinVal<T>() { \
return limit_value; \
}
OF_PP_FOR_EACH_TUPLE
(
SPECIALIZE_MIN_VAL
,
MIN_VAL_SEQ
);
#undef SPECIALIZE_MIN_VAL
// Returns the address of a function-local static constant holding
// GetZeroVal<T>(); the same pointer is returned on every call for a given T.
template<typename T>
const T* GetZeroPtr() {
  static const T zero_value = GetZeroVal<T>();
  return &zero_value;
}
// Returns the address of a function-local static constant holding
// GetOneVal<T>(); the same pointer is returned on every call for a given T.
template<typename T>
const T* GetOnePtr() {
  static const T one_value = GetOneVal<T>();
  return &one_value;
}
// Half-precision overload: the value is produced by reinterpreting a 16-bit
// pattern as T instead of converting from a number.
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetZeroVal() {
  uint16_t bits = 0x0;  // Decimal: 0; Binary: 0 00000 0000000000
  return *reinterpret_cast<T*>(&bits);
}
// Half-precision overload: 1.0 expressed as its 16-bit pattern.
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetOneVal() {
  uint16_t bits = 0x3c00;  // Decimal: 15360; Binary: 0 01111 0000000000
  return *reinterpret_cast<T*>(&bits);
}
// Half-precision overload: largest finite value expressed as its 16-bit pattern.
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMaxVal() {
  uint16_t bits = 0x7bff;  // Decimal: 31743; Binary: 0 11110 1111111111
  return *reinterpret_cast<T*>(&bits);
}
// Half-precision overload: most negative finite value (sign bit set on the
// GetMaxVal pattern) expressed as its 16-bit pattern.
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMinVal() {
  uint16_t bits = 0xfbff;  // Decimal: 64511; Binary: 1 11110 1111111111
  return *reinterpret_cast<T*>(&bits);
}
// Maps a (device, host element type) pair to the element type used on that
// device. By default the type is unchanged.
template<DeviceType, typename T>
struct DevDType {
  typedef T type;
};

// On GPU builds host float16 is represented by the device `half` type.
// The ROCm build reuses DeviceType::kCUDA, so the CUDA and ROCm variants were
// the very same specialization written twice; a single guarded copy avoids a
// redefinition if both macros are ever defined together.
#if defined(WITH_CUDA) || defined(WITH_ROCM)
template<>
struct DevDType<DeviceType::kCUDA, float16> {
  static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
  typedef half type;
};
#endif  // defined(WITH_CUDA) || defined(WITH_ROCM)
// Func
// Run-time queries over DataType values. Only declarations live here; the
// definitions are in the corresponding source file.
bool IsBoolDataType(DataType data_type);
bool IsIntegralDataType(DataType data_type);
bool IsFloatingDataType(DataType data_type);
bool IsSupportRequireGradDataType(DataType data_type);
bool IsPODDataType(DataType data_type);
bool IsPODAndHalfDataType(DataType data_type);
bool IsIndexDataType(DataType data_type);
bool NotSupportBoxingDataType(DataType data_type);
// Size in bytes of one element of `data_type`.
size_t GetSizeOfDataType(DataType data_type);
// Two OptInt64 messages are equal when both are unset, or when both are set
// and carry the same value.
inline bool operator==(const OptInt64& lhs, const OptInt64& rhs) {
  if (lhs.has_value() != rhs.has_value()) { return false; }
  if (!lhs.has_value()) { return true; }
  return lhs.value() == rhs.value();
}
// Aborts through LOG(FATAL) when `data_type` does not match the enum value
// associated with T. The check is skipped when T is void, char or long, and
// when `data_type` is kChar.
template<typename T>
void CheckDataType(DataType data_type) {
  const bool is_exempt_type = std::is_same<T, void>::value || std::is_same<T, char>::value
                              || std::is_same<T, long>::value;
  LOG_IF(FATAL, (!is_exempt_type && data_type != DataType::kChar
                 && data_type != GetDataType<T>::value))
      << data_type << " " << GetDataType<T>::value;
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_H_
oneflow/core/common/data_type.proto
0 → 100644
View file @
21d47d0e
syntax
=
"proto2"
;
package
oneflow
;
// Element types known to the framework. The numeric tags are part of the
// serialized format and must not be renumbered.
enum DataType {
  kInvalidDataType = 0;
  kChar = 1;
  kFloat = 2;
  kDouble = 3;
  kInt8 = 4;
  kInt32 = 5;
  kInt64 = 6;
  kUInt8 = 7;
  kOFRecord = 8;
  kFloat16 = 9;
  kTensorBuffer = 10;
  kBFloat16 = 11;
  kBool = 12;
  kUInt16 = 13;
  kUInt32 = 14;
  kUInt64 = 15;
  kUInt128 = 16;
  kInt16 = 17;
  kInt128 = 18;
  kComplex32 = 19;
  kComplex64 = 20;
  kComplex128 = 21;
}
// Optional 64-bit integer; reads of an unset `value` yield the default -1.
message OptInt64 {
  optional int64 value = 1 [default = -1];
}
oneflow/core/common/data_type_converter.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
#ifdef WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#endif
#include <cstdint>
#include <limits>
#include <type_traits>
#include "oneflow/core/common/data_type.h"
namespace
oneflow
{
// True when T is one of the floating types (IsFloating) or a half-precision
// type (IsFloat16).
template<typename T>
struct IsFloatingOrHalf {
  static const bool value = IsFloat16<T>::value || IsFloating<T>::value;
};
// True when T is a standard arithmetic type or a half-precision type.
template<typename T>
struct IsArithmeticOrHalf {
  static const bool value = IsFloat16<T>::value || std::is_arithmetic<T>::value;
};
// Compile-time answer to "can a value of type From fall outside the range of
// type To?". `value` is true when any of the listed situations applies.
template<typename From, typename To>
struct NeedsClamp {
  static const bool from_fp = IsFloatingOrHalf<From>::value;
  static const bool to_fp = IsFloatingOrHalf<To>::value;
  static const bool from_fp16 = IsFloat16<From>::value;
  static const bool to_fp16 = IsFloat16<To>::value;
  static const bool from_unsigned = std::is_unsigned<From>::value;
  static const bool to_unsigned = std::is_unsigned<To>::value;
  static const bool value =
      // narrowing within the same kind (fp -> fp or int -> int)
      (from_fp == to_fp && sizeof(To) < sizeof(From))
      // fp32 has range in excess of (u)int64
      || (from_fp && !to_fp)
      // converting to unsigned requires clamping negatives to zero
      || (!from_unsigned && to_unsigned)
      // an unsigned source only fits a signed destination that has more bits
      || (from_unsigned && !to_unsigned && sizeof(To) <= sizeof(From))
      // float16 destination that is not wider than the source
      || (to_fp16 && sizeof(To) <= sizeof(From));
};
// A bool source never needs clamping, whatever the destination type.
template<typename To>
struct NeedsClamp<bool, To> {
  static const bool value = false;
};
// Primary template of the clamping helper: T is the destination type, U the
// source type, Enabled is the enable_if hook used by the partial
// specializations. The primary itself is empty (it has no Call member).
template
<
typename
T
,
typename
U
,
typename
Enabled
=
void
>
struct
ClampHelper
{};
// Signed source (floating-point or signed integer) -> signed destination:
// saturate at both ends of T's range.
template<typename T, typename U>
struct ClampHelper<
    T, U,
    std::enable_if_t<
        NeedsClamp<U, T>::value && std::is_signed<U>::value && std::is_signed<T>::value, void>> {
  OF_DEVICE_FUNC static const T Call(U value) {
    if (value <= GetMinVal<T>()) { return GetMinVal<T>(); }
    if (value >= GetMaxVal<T>()) { return GetMaxVal<T>(); }
    return static_cast<T>(value);
  }
};
// Floating-point source -> unsigned destination: saturate at both ends of
// T's range (the lower end being T's minimum, i.e. negatives are clamped up).
template<typename T, typename U>
struct ClampHelper<T, U,
                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value
                                        && IsFloatingOrHalf<U>::value
                                        && std::is_unsigned<T>::value,
                                    void>> {
  OF_DEVICE_FUNC static const T Call(U value) {
    if (value <= GetMinVal<T>()) { return GetMinVal<T>(); }
    if (value >= GetMaxVal<T>()) { return GetMaxVal<T>(); }
    return static_cast<T>(value);
  }
};
// Signed integer source -> unsigned destination: negatives become 0; the
// upper bound is compared in the unsigned domain so that the comparison
// against T's maximum is not affected by sign conversion.
template<typename T, typename U>
struct ClampHelper<T, U,
                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value
                                        && std::is_integral<U>::value
                                        && std::is_unsigned<T>::value,
                                    void>> {
  OF_DEVICE_FUNC static const T Call(U value) {
    if (value <= 0) { return 0; }
    if (static_cast<std::make_unsigned_t<U>>(value) >= GetMaxVal<T>()) { return GetMaxVal<T>(); }
    return static_cast<T>(value);
  }
};
// Unsigned source -> any destination: only the upper bound can be exceeded.
template<typename T, typename U>
struct ClampHelper<T, U,
                   std::enable_if_t<NeedsClamp<U, T>::value && std::is_unsigned<U>::value, void>> {
  OF_DEVICE_FUNC static const T Call(U value) {
    if (value >= GetMaxVal<T>()) { return GetMaxVal<T>(); }
    return static_cast<T>(value);
  }
};
// No clamping required: the value is passed through with an implicit
// conversion from U to T.
template<typename T, typename U>
struct ClampHelper<T, U, std::enable_if_t<!NeedsClamp<U, T>::value, void>> {
  OF_DEVICE_FUNC static const T Call(U value) {
    const T converted = value;
    return converted;
  }
};
// uint32 -> int32: a set top bit means the value is above INT32_MAX, so it
// saturates to 0x7fffffff; anything else is returned unchanged.
OF_DEVICE_FUNC const int32_t Clamp(uint32_t value) {
  if (value & 0x80000000u) { return 0x7fffffff; }
  return value;
}
// int32 -> uint32: negatives become 0, everything else is returned unchanged.
OF_DEVICE_FUNC const uint32_t Clamp(int32_t value) {
  if (value < 0) { return 0u; }
  return value;
}
// int64 -> int32: saturate at both ends of the int32 range.
OF_DEVICE_FUNC const int32_t Clamp(int64_t value) {
  if (value < static_cast<int64_t>(GetMinVal<int32_t>())) { return GetMinVal<int32_t>(); }
  if (value > static_cast<int64_t>(GetMaxVal<int32_t>())) { return GetMaxVal<int32_t>(); }
  return static_cast<int32_t>(value);
}
// uint64 -> int32: only the upper bound can be exceeded.
template<>
struct ClampHelper<int32_t, uint64_t> {
  OF_DEVICE_FUNC static const int32_t Call(uint64_t value) {
    if (value > static_cast<uint64_t>(GetMaxVal<int32_t>())) { return GetMaxVal<int32_t>(); }
    return static_cast<int32_t>(value);
  }
};
// int64 -> uint32: negatives become 0, values above UINT32_MAX saturate.
template<>
struct ClampHelper<uint32_t, int64_t> {
  OF_DEVICE_FUNC static const uint32_t Call(int64_t value) {
    if (value < 0) { return 0; }
    if (value > static_cast<int64_t>(GetMaxVal<uint32_t>())) { return GetMaxVal<uint32_t>(); }
    return static_cast<uint32_t>(value);
  }
};
// uint64 -> uint32: values above UINT32_MAX saturate.
template<>
struct ClampHelper<uint32_t, uint64_t> {
  OF_DEVICE_FUNC static const uint32_t Call(uint64_t value) {
    if (value > static_cast<uint64_t>(GetMaxVal<uint32_t>())) { return GetMaxVal<uint32_t>(); }
    return static_cast<uint32_t>(value);
  }
};
// Any source -> bool: the usual truthiness conversion (non-zero becomes true).
template<typename T>
struct ClampHelper<bool, T> {
  OF_DEVICE_FUNC static const bool Call(T value) {
    const bool as_bool = static_cast<bool>(value);
    return as_bool;
  }
};
// Any source -> float16 (host only): the value is first passed through
// ClampHelper<T, float>, then saturated to the finite float16 range and
// converted.
// NOTE(review): ClampHelper<T, float> has T as destination and float as
// source, i.e. `value` makes a round trip T -> float -> T before the range
// check; confirm this is intended rather than ClampHelper<float, T>.
template<typename T>
struct ClampHelper<float16, T> {
  inline static const float16 Call(T value) {
    const auto pre_clamped = ClampHelper<T, float>::Call(value);
    if (pre_clamped < GetMinVal<float16>()) { return GetMinVal<float16>(); }
    if (pre_clamped > GetMaxVal<float16>()) { return GetMaxVal<float16>(); }
    return static_cast<float16>(pre_clamped);
  }
};
// float16 source -> any destination (host only): widen to float, then reuse
// the float -> T clamping rules.
template<typename T>
struct ClampHelper<T, float16> {
  inline static const T Call(float16 value) {
    const float widened = static_cast<float>(value);
    return ClampHelper<T, float>::Call(widened);
  }
};
// float16 -> float16: identity, nothing to clamp.
inline const float16 Clamp(float16 value) { return value; }
// Generic entry point: converts `value` to T, saturating according to the
// ClampHelper specialization selected for the (T, U) pair.
// T must be given explicitly; U is deduced from the argument.
template<typename T, typename U>
OF_DEVICE_FUNC const T Clamp(U value) {
  using Helper = ClampHelper<T, U>;
  return Helper::Call(value);
}
namespace
{
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
// Device-only rounding helpers. The second (unnamed) parameter is a tag whose
// type selects the integer result type; its value is ignored.
// 32-bit results use the `_rn` conversion intrinsics, 64-bit results use the
// `_rd` intrinsics applied to `f + 0.5f`. `long` / `unsigned long` pick one of
// the two depending on their size on the target.

// ---- float input ----
inline __device__ int cuda_round_helper(float f, int) { return __float2int_rn(f); }
inline __device__ unsigned cuda_round_helper(float f, unsigned) { return __float2uint_rn(f); }
inline __device__ long long cuda_round_helper(float f, long long) {
  return __float2ll_rd(f + 0.5f);
}
inline __device__ unsigned long long cuda_round_helper(float f, unsigned long long) {
  return __float2ull_rd(f + 0.5f);
}
inline __device__ long cuda_round_helper(float f, long) {
  return sizeof(long) == sizeof(int) ? __float2int_rn(f) : __float2ll_rd(f + 0.5f);
}
inline __device__ unsigned long cuda_round_helper(float f, unsigned long) {
  return sizeof(unsigned long) == sizeof(unsigned int) ? __float2uint_rn(f)
                                                       : __float2ull_rd(f + 0.5f);
}

// ---- double input ----
inline __device__ int cuda_round_helper(double f, int) { return __double2int_rn(f); }
inline __device__ unsigned cuda_round_helper(double f, unsigned) { return __double2uint_rn(f); }
inline __device__ long long cuda_round_helper(double f, long long) {
  return __double2ll_rd(f + 0.5f);
}
inline __device__ unsigned long long cuda_round_helper(double f, unsigned long long) {
  return __double2ull_rd(f + 0.5f);
}
inline __device__ long cuda_round_helper(double f, long) {
  return sizeof(long) == sizeof(int) ? __double2int_rn(f) : __double2ll_rd(f + 0.5f);
}
inline __device__ unsigned long cuda_round_helper(double f, unsigned long) {
  return sizeof(unsigned long) == sizeof(unsigned int) ? __double2uint_rn(f)
                                                       : __double2ull_rd(f + 0.5f);
}
#endif
// Implementation selector: the two bool parameters record whether the output
// and the input are floating-point (or half) types; one partial
// specialization exists per combination.
template<typename Out, typename In, bool OutIsFp = IsFloatingOrHalf<Out>::value,
         bool InIsFp = IsFloatingOrHalf<In>::value>
struct ConverterBase;

// Public converter for a pair of distinct arithmetic (or half) types.
// Inherits Convert / ConvertNorm / ConvertSat / ConvertSatNorm from the
// matching ConverterBase.
template<typename Out, typename In>
struct Converter : ConverterBase<Out, In> {
  static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,
                "Default ConverterBase can only be used with arithmetic types.");
};
// Converts between two FP types.
// All four flavours are a plain implicit conversion: no scaling, no clamping.
template<typename Out, typename In>
struct ConverterBase<Out, In, true, true> {
  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
  OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return Convert(value); }
  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Convert(value); }
  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return Convert(value); }
};
// Converts integral to FP type.
//  * Convert / ConvertSat: plain implicit conversion.
//  * ConvertNorm / ConvertSatNorm: scale so that the largest value of In maps
//    to 1.
template<typename Out, typename In>
struct ConverterBase<Out, In, true, false> {
  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Convert(value); }
  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
    return value * (Out(1) / (GetMaxVal<In>()));
  }
  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return ConvertNorm(value); }
};
// Converts integral to float16.
// Every flavour runs the integral -> float conversion first and then narrows
// the float result to float16.
template<typename In>
struct ConverterBase<float16, In, true, false> {
  OF_DEVICE_FUNC static const float16 Convert(In value) {
    return static_cast<float16>(ConverterBase<float, In, true, false>::Convert(value));
  }
  OF_DEVICE_FUNC static const float16 ConvertSat(In value) {
    return static_cast<float16>(ConverterBase<float, In, true, false>::ConvertSat(value));
  }
  OF_DEVICE_FUNC static const float16 ConvertNorm(In value) {
    return static_cast<float16>(ConverterBase<float, In, true, false>::ConvertNorm(value));
  }
  OF_DEVICE_FUNC static const float16 ConvertSatNorm(In value) {
    return static_cast<float16>(ConverterBase<float, In, true, false>::ConvertSatNorm(value));
  }
};
// Converts FP to integral type.
//  * Convert / ConvertSat: round to an integer, then clamp to Out's range.
//  * ConvertNorm / ConvertSatNorm: scale by Out's largest value first
//    (an input of 1.0 maps to GetMaxVal<Out>()), round, then clamp.
// Device code rounds with cuda_round_helper, host code with std::round.
template<typename Out, typename In>
struct ConverterBase<Out, In, false, true> {
  OF_DEVICE_FUNC static const Out Convert(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    return Clamp<Out>(cuda_round_helper(value, Out()));
#else
    return Clamp<Out>(std::round(value));
#endif
  }

  OF_DEVICE_FUNC static const Out ConvertSat(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    return Clamp<Out>(cuda_round_helper(value, Out()));
#else
    return Clamp<Out>(std::round(value));
#endif
  }

  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    return Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()));
#else
    // Clamp on the host as well: the device branch already does, and an
    // unclamped out-of-range floating-point -> integer conversion is
    // undefined behaviour.
    return Clamp<Out>(std::round(value * GetMaxVal<Out>()));
#endif
  }

  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    // Unsigned outputs saturate the input to [0, 1] before scaling; signed
    // outputs scale first and clamp afterwards.
    return std::is_signed<Out>::value
               ? Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()))
               : cuda_round_helper(GetMaxVal<Out>() * __saturatef(value), Out());
#else
    return Clamp<Out>(std::round(value * GetMaxVal<Out>()));
#endif
  }
};
// Converts signed to signed, unsigned to unsigned or unsigned to signed.
//  * Convert: plain implicit conversion.
//  * ConvertNorm / ConvertSatNorm: rescale so that In's largest value maps to
//    Out's largest value; the arithmetic is done in float and the result goes
//    through the float -> Out converter.
//  * ConvertSat: clamp to Out's range.
template<typename Out, typename In, bool IsOutSigned = std::is_signed<Out>::value,
         bool IsInSigned = std::is_signed<In>::value>
struct ConvertIntInt {
  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
    const auto scale = 1.0f * GetMaxVal<Out>() / GetMaxVal<In>();
    return Converter<Out, float>::Convert(value * scale);
  }
  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }
  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return ConvertNorm(value); }
};
// Converts signed to unsigned integer.
//  * Convert: plain implicit conversion.
//  * ConvertNorm: rescale so that In's largest value maps to Out's largest
//    value (computed in float, then converted with the float -> Out converter).
//  * ConvertSat: clamp to Out's range.
//  * ConvertSatNorm: like ConvertNorm, but negative inputs yield 0.
template<typename Out, typename In>
struct ConvertIntInt<Out, In, false, true> {
  OF_DEVICE_FUNC static const Out Convert(In value) { return value; }

  OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
    return Converter<Out, float>::Convert(value * (1.0f * GetMaxVal<Out>() / GetMaxVal<In>()));
  }

  OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }

  OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    // cuda_round_helper selects its result type from a second tag argument;
    // every overload takes two parameters, so Out() must be passed here.
    return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal<In>())) * GetMaxVal<Out>(),
                             Out());
#else
    return value < 0 ? 0 : ConvertNorm(value);
#endif
  }
};
// Converts between integral types: all behaviour comes from ConvertIntInt,
// which dispatches on the signedness of both types.
template<typename Out, typename In>
struct ConverterBase<Out, In, false, false> : ConvertIntInt<Out, In> {
  static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,
                "Default ConverterBase can only be used with arithmetic types.");
};
// Pass-through conversion: source and destination types are the same, so
// every flavour returns its argument unchanged.
template<typename T>
struct Converter<T, T> {
  OF_DEVICE_FUNC static const T Convert(T value) { return value; }
  OF_DEVICE_FUNC static const T ConvertSat(T value) { return value; }
  OF_DEVICE_FUNC static const T ConvertNorm(T value) { return value; }
  OF_DEVICE_FUNC static const T ConvertSatNorm(T value) { return value; }
};
// Selects the Converter for a pair of possibly cv/reference-qualified types:
// cv-qualifiers are stripped from the output type, reference and
// cv-qualifiers from the input type.
template<typename raw_out, typename raw_in>
using converter_t =
    Converter<std::remove_cv_t<raw_out>, std::remove_cv_t<std::remove_reference_t<raw_in>>>;
}
// namespace
// Converts `value` to Out using the Convert flavour of the converter chosen
// for (Out, In). Out must be given explicitly; In is deduced.
template<typename Out, typename In>
OF_DEVICE_FUNC const Out Convert(In value) {
  using Impl = converter_t<Out, In>;
  return Impl::Convert(value);
}
// Converts `value` to Out using the ConvertNorm (range-normalizing) flavour of
// the converter chosen for (Out, In).
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertNorm(In value) {
  using Impl = converter_t<Out, In>;
  return Impl::ConvertNorm(value);
}
// Converts `value` to Out using the ConvertSat (saturating) flavour of the
// converter chosen for (Out, In).
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertSat(In value) {
  using Impl = converter_t<Out, In>;
  return Impl::ConvertSat(value);
}
// Converts `value` to Out using the ConvertSatNorm (normalizing and
// saturating) flavour of the converter chosen for (Out, In).
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertSatNorm(In value) {
  using Impl = converter_t<Out, In>;
  return Impl::ConvertSatNorm(value);
}
}
// namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
oneflow/core/common/data_type_converter_test.cpp
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "util.h"
#include "oneflow/core/common/data_type_converter.h"
#include "oneflow/core/common/data_type_converter_test_static.h"
#ifdef __CUDA_ARCH__
#include <cuda_runtime.h>
#elif defined(__HIP_DEVICE_COMPILE__)
#include <hip/hip_runtime.h>
#else
#include <cmath>
#endif
namespace
oneflow
{
namespace {

// cpp17 std::clamp possible implementation
// Returns a reference to `lo` when v < lo, to `hi` when hi < v, and to `v`
// itself otherwise. Only operator< is used.
template<class T>
constexpr const T& clamp(const T& v, const T& lo, const T& hi) {
  if (v < lo) { return lo; }
  if (hi < v) { return hi; }
  return v;
}

}  // namespace
// Clamp<T>(x) for every fixed-width integer destination: in-range values are
// kept (fractions dropped), out-of-range integers and floats saturate.
TEST(ClampTest, Clamp) {
  // -> uint8_t
  ASSERT_TRUE(Clamp<uint8_t>(0) == 0);
  ASSERT_TRUE(Clamp<uint8_t>(255) == 255);
  ASSERT_TRUE(Clamp<uint8_t>(100) == 100);
  ASSERT_TRUE(Clamp<uint8_t>(100.3) == 100);
  ASSERT_TRUE(Clamp<uint8_t>(256) == 255);
  ASSERT_TRUE(Clamp<uint8_t>(-4) == 0);
  ASSERT_TRUE(Clamp<uint8_t>(-4.0f) == 0);
  ASSERT_TRUE(Clamp<uint8_t>(1e+20f) == 255);
  ASSERT_TRUE(Clamp<uint8_t>(-1e+20f) == 0);
  ASSERT_TRUE(Clamp<uint8_t>(1e+200) == 255);
  ASSERT_TRUE(Clamp<uint8_t>(-1e+200) == 0);

  // -> int8_t
  ASSERT_TRUE(Clamp<int8_t>(-4) == -4);
  ASSERT_TRUE(Clamp<int8_t>(-4.2) == -4);
  ASSERT_TRUE(Clamp<int8_t>(4.2) == 4);
  ASSERT_TRUE(Clamp<int8_t>(127) == 127);
  ASSERT_TRUE(Clamp<int8_t>(128) == 127);
  ASSERT_TRUE(Clamp<int8_t>(256) == 127);
  ASSERT_TRUE(Clamp<int8_t>(-128) == -128);
  ASSERT_TRUE(Clamp<int8_t>(-256) == -128);
  ASSERT_TRUE(Clamp<int8_t>(1e+20f) == 127);
  ASSERT_TRUE(Clamp<int8_t>(-1e+20f) == -128);
  ASSERT_TRUE(Clamp<int8_t>(1e+200) == 127);
  ASSERT_TRUE(Clamp<int8_t>(-1e+200) == -128);

  // -> uint16_t
  ASSERT_TRUE(Clamp<uint16_t>(0) == 0);
  ASSERT_TRUE(Clamp<uint16_t>(0xffff) == 0xffff);
  ASSERT_TRUE(Clamp<uint16_t>(100) == 100);
  ASSERT_TRUE(Clamp<uint16_t>(100.3) == 100);
  ASSERT_TRUE(Clamp<uint16_t>(0x10000) == 0xffff);
  ASSERT_TRUE(Clamp<uint16_t>(-4) == 0);
  ASSERT_TRUE(Clamp<uint16_t>(-4.0f) == 0);
  ASSERT_TRUE(Clamp<uint16_t>(1e+20f) == 0xffff);
  ASSERT_TRUE(Clamp<uint16_t>(-1e+20f) == 0);
  ASSERT_TRUE(Clamp<uint16_t>(1e+200) == 0xffff);
  ASSERT_TRUE(Clamp<uint16_t>(-1e+200) == 0);

  // -> int16_t
  ASSERT_TRUE(Clamp<int16_t>(-4) == -4);
  ASSERT_TRUE(Clamp<int16_t>(-4.2) == -4);
  ASSERT_TRUE(Clamp<int16_t>(4.2) == 4);
  ASSERT_TRUE(Clamp<int16_t>(0x7fff) == 0x7fff);
  ASSERT_TRUE(Clamp<int16_t>(0x8000) == 0x7fff);
  ASSERT_TRUE(Clamp<int16_t>(0x10000) == 0x7fff);
  ASSERT_TRUE(Clamp<int16_t>(-0x8000) == -0x8000);
  ASSERT_TRUE(Clamp<int16_t>(-0x10000) == -0x8000);
  ASSERT_TRUE(Clamp<int16_t>(1e+20f) == 0x7fff);
  ASSERT_TRUE(Clamp<int16_t>(-1e+20f) == -0x8000);
  ASSERT_TRUE(Clamp<int16_t>(1e+200) == 0x7fff);
  ASSERT_TRUE(Clamp<int16_t>(-1e+200) == -0x8000);

  // -> uint32_t
  ASSERT_TRUE(Clamp<uint32_t>(0) == 0);
  ASSERT_TRUE(Clamp<uint32_t>(0xffffffffLL) == 0xffffffffLL);
  ASSERT_TRUE(Clamp<uint32_t>(100) == 100);
  ASSERT_TRUE(Clamp<uint32_t>(100.3) == 100);
  ASSERT_TRUE(Clamp<uint32_t>(0x100000000LL) == 0xffffffffLL);
  ASSERT_TRUE(Clamp<uint32_t>(-4) == 0);
  ASSERT_TRUE(Clamp<uint32_t>(-4.0f) == 0);
  ASSERT_TRUE(Clamp<uint32_t>(1e+20f) == 0xffffffffu);
  ASSERT_TRUE(Clamp<uint32_t>(-1.0e+20f) == 0);
  ASSERT_TRUE(Clamp<uint32_t>(1e+200) == 0xffffffffu);
  ASSERT_TRUE(Clamp<uint32_t>(-1.0e+200) == 0);

  // -> int32_t
  ASSERT_TRUE(Clamp<int32_t>(-4) == -4);
  ASSERT_TRUE(Clamp<int32_t>(-4LL) == -4);
  ASSERT_TRUE(Clamp<int32_t>(-4.2) == -4);
  ASSERT_TRUE(Clamp<int32_t>(4.2) == 4);
  ASSERT_TRUE(Clamp<int32_t>(0x7fffffff) == 0x7fffffff);
  ASSERT_TRUE(Clamp<int32_t>(0x80000000L) == 0x7fffffff);
  ASSERT_TRUE(Clamp<int32_t>(0x100000000L) == 0x7fffffff);
  ASSERT_TRUE(Clamp<int32_t>(-0x80000000LL) == -0x7fffffff - 1);
  ASSERT_TRUE(Clamp<int32_t>(-0x100000000LL) == -0x7fffffff - 1);
  ASSERT_TRUE(Clamp<int32_t>(1.0e+20f) == 0x7fffffff);
  ASSERT_TRUE(Clamp<int32_t>(-1.0e+20f) == -0x80000000L);
  ASSERT_TRUE(Clamp<int32_t>(1.0e+200) == 0x7fffffff);
  ASSERT_TRUE(Clamp<int32_t>(-1.0e+200) == -0x80000000L);

  // -> int64_t
  ASSERT_TRUE(Clamp<int64_t>(1.0e+200) == 0x7fffffffffffffffLL);
  ASSERT_TRUE(Clamp<int64_t>(-1.0e+200) == -0x7fffffffffffffffLL - 1);

  // -> uint64_t
  ASSERT_TRUE(Clamp<uint64_t>(1.0e+200) == 0xffffffffffffffffULL);
  ASSERT_TRUE(Clamp<uint64_t>(-1.0e+200) == 0);
}
// ConvertSat<int>(float) must equal "round, then clamp to the destination
// range" for a grid of floats built as sig * 2^exp.
TEST(ConvertSat, float2int) {
  FOR_RANGE(int32_t, exp, -10, 100) {
    FOR_RANGE(float, sig, -256, 257) {
      float f = ldexpf(sig, exp);
      float integral;
      float fract = modff(f, &integral);
      // Exact halves are skipped (presumably because tie rounding may differ
      // between the reference and the implementation under test -- confirm).
      if (fract == 0.5f || fract == -0.5f) { continue; }
      double rounded = roundf(f);

      int64_t clamped = clamp<double>(rounded, -128, 127);
      ASSERT_EQ(ConvertSat<int8_t>(f), clamped) << " with f = " << f;

      clamped = clamp<double>(rounded, 0, 255);
      ASSERT_EQ(ConvertSat<uint8_t>(f), clamped) << " with f = " << f;

      clamped = clamp<double>(rounded, -0x8000, 0x7fff);
      ASSERT_EQ(ConvertSat<int16_t>(f), clamped) << " with f = " << f;

      clamped = clamp<double>(rounded, 0, 0xffff);
      ASSERT_EQ(ConvertSat<uint16_t>(f), clamped) << " with f = " << f;

      clamped = clamp<double>(rounded, int32_t(~0x7fffffff), 0x7fffffff);
      ASSERT_EQ(ConvertSat<int32_t>(f), clamped) << " with f = " << f;

      clamped = clamp<double>(rounded, 0, 0xffffffffu);
      ASSERT_EQ(ConvertSat<uint32_t>(f), clamped) << " with f = " << f;
    }
  }
}
// Integer -> integer normalization: the source maximum maps to the
// destination maximum.
TEST(ConvertNorm, int2int) {
  EXPECT_EQ((ConvertNorm<uint8_t, uint8_t>(0)), 0);
  EXPECT_EQ((ConvertNorm<uint8_t, int8_t>(127)), 255);
}
// Float -> integer normalization: 1.0 maps to the destination maximum and
// -1.0 to its negation.
TEST(ConvertNorm, float2int) {
  // 8-bit destinations
  EXPECT_EQ(ConvertNorm<uint8_t>(0.0f), 0);
  EXPECT_EQ(ConvertNorm<uint8_t>(0.499f), 127);
  EXPECT_EQ(ConvertNorm<uint8_t>(1.0f), 255);
  EXPECT_EQ(ConvertNorm<int8_t>(1.0f), 127);
  EXPECT_EQ(ConvertNorm<int8_t>(0.499f), 63);
  EXPECT_EQ(ConvertNorm<int8_t>(-1.0f), -127);

  // 16-bit destinations
  EXPECT_EQ(ConvertNorm<uint16_t>(0.0f), 0);
  EXPECT_EQ(ConvertNorm<uint16_t>(1.0f), 0xffff);
  EXPECT_EQ(ConvertNorm<int16_t>(1.0f), 0x7fff);
  EXPECT_EQ(ConvertNorm<int16_t>(-1.0f), -0x7fff);
}
// Float -> integer normalization with saturation: inputs beyond +/-1.0 hit
// the destination limits instead of overflowing.
TEST(ConvertSatNorm, float2int) {
  // 8-bit destinations
  EXPECT_EQ(ConvertSatNorm<uint8_t>(2.0f), 255);
  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.499f), 127);
  EXPECT_EQ(ConvertSatNorm<uint8_t>(-2.0f), 0);
  EXPECT_EQ(ConvertSatNorm<int8_t>(2.0f), 127);
  EXPECT_EQ(ConvertSatNorm<int8_t>(0.499f), 63);
  EXPECT_EQ(ConvertSatNorm<int8_t>(-2.0f), -128);

  // rounding just below / above half of one output step
  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.4f / 255), 0);
  EXPECT_EQ(ConvertSatNorm<uint8_t>(0.6f / 255), 1);

  // 16-bit destination
  EXPECT_EQ(ConvertSatNorm<int16_t>(2.0f), 0x7fff);
  EXPECT_EQ(ConvertSatNorm<int16_t>(-2.0f), -0x8000);
}
// Integer -> float normalization: the source maximum maps to 1.0.
TEST(ConvertNorm, int2float) {
  // unsigned 8-bit source
  EXPECT_EQ((ConvertNorm<float, uint8_t>(255)), 1.0f);
  EXPECT_NEAR((ConvertNorm<float, uint8_t>(127)), 1.0f * 127 / 255, 1e-7f);

  // signed 8-bit source
  EXPECT_EQ((ConvertNorm<float, int8_t>(127)), 1.0f);
  EXPECT_NEAR((ConvertNorm<float, int8_t>(64)), 1.0f * 64 / 127, 1e-7f);
}
// A large int64 clamped to float16 must saturate at +/-65504 and agree with
// clamping through float first.
TEST(Clamp1, int64_2_float16) {
  int64_t big_num = 0x0FFFFFFFFFFFFFFF;

  EXPECT_EQ(static_cast<float>(Clamp<float16>(big_num)), Clamp<float16>(Clamp<float>(big_num)));
  EXPECT_EQ(65504.0f, Clamp<float16>(big_num));
  EXPECT_EQ(-65504.0f, Clamp<float16>(-big_num));
}
// The float16 value +/-65504 converts to int64 without change.
TEST(Clamp2, float16_2_int64) {
  float16 fp16 = static_cast<float16>(65504.0f);

  EXPECT_EQ(65504, Clamp<int64_t>(fp16));
  EXPECT_EQ(-65504, Clamp<int64_t>(-fp16));
}
}
// namespace oneflow
\ No newline at end of file
oneflow/core/common/data_type_converter_test_static.h
0 → 100644
View file @
21d47d0e
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
#include "oneflow/core/common/data_type_converter.h"
namespace oneflow {

namespace {

// Compile-time checks of NeedsClamp<From, To>::value. Nothing is declared
// here; including this header merely triggers the assertions.

// fp to int
static_assert(NeedsClamp<float, int8_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint8_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int16_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint16_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int32_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint32_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int64_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint64_t>::value, "Float range exceeds all ints up to 64b");

// same size, different signedness
static_assert(NeedsClamp<int8_t, uint8_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint8_t, int8_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int16_t, uint16_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint16_t, int16_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int32_t, uint32_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint32_t, int32_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int64_t, uint64_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint64_t, int64_t>::value, "Signed <-> unsigned requires clamp");

// larger, but unsigned
static_assert(NeedsClamp<int8_t, uint16_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int8_t, uint32_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int8_t, uint64_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int16_t, uint32_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int16_t, uint64_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int32_t, uint64_t>::value, "Need to clamp negatives to 0");

// identical types, wider integer destinations, and float -> float/double:
// no clamp expected
static_assert(!NeedsClamp<int8_t, int8_t>::value, "Clamping not required");
static_assert(!NeedsClamp<int8_t, int16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<uint8_t, int16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<uint8_t, uint16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<float, float>::value, "Clamping not required");
static_assert(!NeedsClamp<float, double>::value, "Clamping not required");

}  // namespace

}  // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
Prev
1
…
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment