Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
#define ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/buffer.h"
namespace oneflow {
template<typename T>
class BufferMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(BufferMgr);
~BufferMgr() = default;
void NewBuffer(const std::string& buffer_name, size_t buffer_size) {
CHECK(name2buffer_.emplace(buffer_name, std::make_unique<Buffer<T>>(buffer_size)).second);
}
Buffer<T>* Get(const std::string& buffer_name) const {
const auto& iter = name2buffer_.find(buffer_name);
CHECK(iter != name2buffer_.end()) << "buffer_name: " << buffer_name;
return iter->second.get();
}
private:
friend class Singleton<BufferMgr>;
BufferMgr() = default;
HashMap<std::string, std::unique_ptr<Buffer<T>>> name2buffer_;
};
static const std::string kBufferNameGlobalWaitJobId = "GlobalWaitJobId";
inline std::string GetCallbackNotifierBufferName(const std::string& job_name) {
static const std::string prefix = "CallbackNotifier-";
return prefix + job_name;
}
inline std::string GetInputCriticalSectionWaitBufferName(const std::string& job_name) {
static const std::string prefix = "InputCriticalSectionWait-";
return prefix + job_name;
}
inline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) {
static const std::string prefix = "InputCriticalSectionCallback-";
return prefix + job_name;
}
inline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) {
static const std::string prefix = "OutputCriticalSectionWait-";
return prefix + job_name;
}
inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) {
static const std::string prefix = "OutputCriticalSectionCallback-";
return prefix + job_name;
}
inline std::string GetForeignInputBufferName(const std::string& job_name) {
static const std::string prefix = "ForeignInput-";
return prefix + job_name;
}
inline std::string GetForeignOutputBufferName(const std::string& job_name) {
static const std::string prefix = "ForeignOutput-";
return prefix + job_name;
}
inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "ForeignInput-";
return prefix + job_name + "-" + op_name;
}
inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "ForeignOutput-";
return prefix + job_name + "-" + op_name;
}
inline std::string GetSourceTickBufferName(const std::string& job_name) {
static const std::string prefix = "SourceTick-";
return prefix + job_name;
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BUFFER_MANAGER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/cached_caller.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
bool IsThreadLocalCacheEnabled() {
if (Singleton<ResourceDesc, ForSession>::Get() == nullptr) { return true; }
return Singleton<ResourceDesc, ForSession>::Get()->enable_thread_local_cache();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
#define ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
#include <list>
#include <tuple>
#include <thread>
#include "oneflow/core/common/function_traits.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/tuple_hash.h"
// gcc 11 falsely reports error:
// ‘void operator delete(void*, std::size_t)’ called on unallocated object ‘cache’
// However, `DeleteAndClear` is only called after `cache` is allocated in
// if (cache == nullptr) block.
// The reason not to use #pragma GCC diagnostic push/pop is that gcc reports
// the error on the caller of `ThreadLocalCachedCall`.
// TODO: replace ThreadLocalCachedCall with ThreadLocalCached decorator?
#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 11
#pragma GCC diagnostic ignored "-Wfree-nonheap-object"
#endif
namespace oneflow {
template<typename T>
void DeleteAndClear(T** ptr, size_t obj_cnt) {
static const size_t kThreshold = 4096;
if (obj_cnt <= kThreshold) {
delete ptr;
} else {
std::thread([](T* ptr) { delete ptr; }, *ptr);
}
*ptr = nullptr;
}
bool IsThreadLocalCacheEnabled();
template<
typename F, typename Ret = typename function_traits<F>::return_type,
typename RawArg = typename std::tuple_element<0, typename function_traits<F>::args_type>::type,
typename Arg = typename std::remove_const<typename std::remove_reference<RawArg>::type>::type>
Ret ThreadLocalCachedCall(size_t max_size, F f, const Arg& arg) {
if (IsThreadLocalCacheEnabled() == false) { return f(arg); }
using HashMap = std::unordered_map<HashEqTraitPtr<const Arg>, Ret>;
using KeyStorage = std::list<std::unique_ptr<Arg>>;
static thread_local HashMap* cache = nullptr;
static thread_local KeyStorage* key_storage = nullptr;
if (cache != nullptr && cache->size() >= max_size) {
DeleteAndClear(&cache, cache->size());
DeleteAndClear(&key_storage, cache->size());
}
if (cache == nullptr) {
cache = new HashMap();
key_storage = new KeyStorage();
}
size_t hash_value = std::hash<Arg>()(arg);
{
HashEqTraitPtr<const Arg> ptr_wrapper(&arg, hash_value);
const auto& iter = cache->find(ptr_wrapper);
if (iter != cache->end()) { return iter->second; }
}
Arg* new_arg = new Arg(arg);
key_storage->emplace_back(new_arg);
HashEqTraitPtr<const Arg> ptr_wrapper(new_arg, hash_value);
return cache->emplace(ptr_wrapper, f(*new_arg)).first->second;
}
template<
typename F, typename Ret = typename function_traits<F>::return_type,
typename RawArg = typename std::tuple_element<0, typename function_traits<F>::args_type>::type,
typename Arg = typename std::remove_const<typename std::remove_reference<RawArg>::type>::type>
std::function<Ret(const Arg&)> WithResultCached(F f) {
auto cache = std::make_shared<std::unordered_map<Arg, Ret>>();
return [cache, f](const Arg& arg) -> Ret {
const auto& iter = cache->find(arg);
if (iter != cache->end()) { return iter->second; }
return cache->emplace(arg, f(arg)).first->second;
};
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CBLAS_H_
#define ONEFLOW_CORE_COMMON_CBLAS_H_
#include <stddef.h>
/*
* Enumerated and derived types
*/
#define CBLAS_INDEX size_t /* this may vary between platforms */
enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };
enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 };
enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 };
enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 };
enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 };
#ifdef __cplusplus
extern "C" {
#endif
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
float cblas_sdsdot(const int N, const float alpha, const float* X, const int incX, const float* Y,
const int incY);
double cblas_dsdot(const int N, const float* X, const int incX, const float* Y, const int incY);
float cblas_sdot(const int N, const float* X, const int incX, const float* Y, const int incY);
double cblas_ddot(const int N, const double* X, const int incX, const double* Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY,
void* dotu);
void cblas_cdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY,
void* dotc);
void cblas_zdotu_sub(const int N, const void* X, const int incX, const void* Y, const int incY,
void* dotu);
void cblas_zdotc_sub(const int N, const void* X, const int incX, const void* Y, const int incY,
void* dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float* X, const int incX);
float cblas_sasum(const int N, const float* X, const int incX);
double cblas_dnrm2(const int N, const double* X, const int incX);
double cblas_dasum(const int N, const double* X, const int incX);
float cblas_scnrm2(const int N, const void* X, const int incX);
float cblas_scasum(const int N, const void* X, const int incX);
double cblas_dznrm2(const int N, const void* X, const int incX);
double cblas_dzasum(const int N, const void* X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float* X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double* X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void* X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void* X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float* X, const int incX, float* Y, const int incY);
void cblas_scopy(const int N, const float* X, const int incX, float* Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float* X, const int incX, float* Y,
const int incY);
void cblas_dswap(const int N, double* X, const int incX, double* Y, const int incY);
void cblas_dcopy(const int N, const double* X, const int incX, double* Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double* X, const int incX, double* Y,
const int incY);
void cblas_cswap(const int N, void* X, const int incX, void* Y, const int incY);
void cblas_ccopy(const int N, const void* X, const int incX, void* Y, const int incY);
void cblas_caxpy(const int N, const void* alpha, const void* X, const int incX, void* Y,
const int incY);
void cblas_zswap(const int N, void* X, const int incX, void* Y, const int incY);
void cblas_zcopy(const int N, const void* X, const int incX, void* Y, const int incY);
void cblas_zaxpy(const int N, const void* alpha, const void* X, const int incX, void* Y,
const int incY);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float* a, float* b, float* c, float* s);
void cblas_srotmg(float* d1, float* d2, float* b1, const float b2, float* P);
void cblas_srot(const int N, float* X, const int incX, float* Y, const int incY, const float c,
const float s);
void cblas_srotm(const int N, float* X, const int incX, float* Y, const int incY, const float* P);
void cblas_drotg(double* a, double* b, double* c, double* s);
void cblas_drotmg(double* d1, double* d2, double* b1, const double b2, double* P);
void cblas_drot(const int N, double* X, const int incX, double* Y, const int incY, const double c,
const double s);
void cblas_drotm(const int N, double* X, const int incX, double* Y, const int incY,
const double* P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float* X, const int incX);
void cblas_dscal(const int N, const double alpha, double* X, const int incX);
void cblas_cscal(const int N, const void* alpha, void* X, const int incX);
void cblas_zscal(const int N, const void* alpha, void* X, const int incX);
void cblas_csscal(const int N, const float alpha, void* X, const int incX);
void cblas_zdscal(const int N, const double alpha, void* X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const float alpha, const float* A, const int lda, const float* X,
const int incX, const float beta, float* Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const int KL, const int KU, const float alpha, const float* A,
const int lda, const float* X, const int incX, const float beta, float* Y,
const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const float* A, const int lda, float* X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const float* A, const int lda, float* X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const float* Ap, float* X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const float* A, const int lda, float* X, const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const float* A, const int lda, float* X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const float* Ap, float* X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const double alpha, const double* A, const int lda, const double* X,
const int incX, const double beta, double* Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const int KL, const int KU, const double alpha, const double* A,
const int lda, const double* X, const int incX, const double beta, double* Y,
const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const double* A, const int lda, double* X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const double* A, const int lda, double* X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const double* Ap, double* X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const double* A, const int lda, double* X, const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const double* A, const int lda, double* X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const double* Ap, double* X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const void* alpha, const void* A, const int lda, const void* X,
const int incX, const void* beta, void* Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const int KL, const int KU, const void* alpha, const void* A,
const int lda, const void* X, const int incX, const void* beta, void* Y,
const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* A, const int lda, void* X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const void* A, const int lda, void* X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* Ap, void* X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* A, const int lda, void* X, const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const void* A, const int lda, void* X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* Ap, void* X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const void* alpha, const void* A, const int lda, const void* X,
const int incX, const void* beta, void* Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const int M,
const int N, const int KL, const int KU, const void* alpha, const void* A,
const int lda, const void* X, const int incX, const void* beta, void* Y,
const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* A, const int lda, void* X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const void* A, const int lda, void* X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* Ap, void* X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* A, const int lda, void* X, const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const int K, const void* A, const int lda, void* X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, const int N,
const void* Ap, void* X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* A, const int lda, const float* X, const int incX,
const float beta, float* Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,
const float alpha, const float* A, const int lda, const float* X, const int incX,
const float beta, float* Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* Ap, const float* X, const int incX,
const float beta, float* Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, const float alpha,
const float* X, const int incX, const float* Y, const int incY, float* A,
const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* X, const int incX, float* A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* X, const int incX, float* Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* X, const int incX, const float* Y, const int incY,
float* A, const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const float* X, const int incX, const float* Y, const int incY,
float* A);
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* A, const int lda, const double* X,
const int incX, const double beta, double* Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,
const double alpha, const double* A, const int lda, const double* X,
const int incX, const double beta, double* Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* Ap, const double* X, const int incX,
const double beta, double* Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, const double alpha,
const double* X, const int incX, const double* Y, const int incY, double* A,
const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* X, const int incX, double* A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* X, const int incX, double* Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* X, const int incX, const double* Y,
const int incY, double* A, const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const double* X, const int incX, const double* Y,
const int incY, double* A);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* A, const int lda, const void* X, const int incX,
const void* beta, void* Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,
const void* alpha, const void* A, const int lda, const void* X, const int incX,
const void* beta, void* Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* Ap, const void* X, const int incX, const void* beta,
void* Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,
const void* X, const int incX, const void* Y, const int incY, void* A,
const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,
const void* X, const int incX, const void* Y, const int incY, void* A,
const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const void* X, const int incX, void* A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const float alpha, const void* X, const int incX, void* A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* X, const int incX, const void* Y, const int incY,
void* A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* X, const int incX, const void* Y, const int incY,
void* Ap);
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* A, const int lda, const void* X, const int incX,
const void* beta, void* Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, const int K,
const void* alpha, const void* A, const int lda, const void* X, const int incX,
const void* beta, void* Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* Ap, const void* X, const int incX, const void* beta,
void* Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,
const void* X, const int incX, const void* Y, const int incY, void* A,
const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, const void* alpha,
const void* X, const int incX, const void* Y, const int incY, void* A,
const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const void* X, const int incX, void* A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const double alpha, const void* X, const int incX, void* A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* X, const int incX, const void* Y, const int incY,
void* A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void* alpha, const void* X, const int incX, const void* Y, const int incY,
void* Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const float alpha, const float* A, const int lda, const float* B, const int ldb,
const float beta, float* C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const float alpha,
const float* A, const int lda, const float* B, const int ldb, const float beta,
float* C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,
const float* A, const int lda, const float beta, float* C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,
const float* A, const int lda, const float* B, const int ldb, const float beta,
float* C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha,
const float* A, const int lda, float* B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const float alpha,
const float* A, const int lda, float* B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const double alpha, const double* A, const int lda, const double* B, const int ldb,
const double beta, double* C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const double alpha,
const double* A, const int lda, const double* B, const int ldb, const double beta,
double* C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,
const double* A, const int lda, const double beta, double* C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,
const double* A, const int lda, const double* B, const int ldb, const double beta,
double* C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha,
const double* A, const int lda, double* B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const double alpha,
const double* A, const int lda, double* B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const void* alpha, const void* A, const int lda, const void* B, const int ldb,
const void* beta, void* C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* beta, void* C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,
const void* A, const int lda, void* B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,
const void* A, const int lda, void* B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const void* alpha, const void* A, const int lda, const void* B, const int ldb,
const void* beta, void* C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* beta, void* C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,
const void* A, const int lda, void* B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N, const void* alpha,
const void* A, const int lda, void* B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const float alpha,
const void* A, const int lda, const float beta, void* C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const float beta,
void* C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const void* beta,
void* C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const double alpha,
const void* A, const int lda, const double beta, void* C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K, const void* alpha,
const void* A, const int lda, const void* B, const int ldb, const double beta,
void* C, const int ldc);
void cblas_xerbla(int p, const char* rout, const char* form, ...);
#ifdef __cplusplus
}
#endif
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CHANNEL_H_
#define ONEFLOW_CORE_COMMON_CHANNEL_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
enum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed };
template<typename T>
class Channel final {
public:
OF_DISALLOW_COPY_AND_MOVE(Channel);
Channel() : is_closed_(false) {}
~Channel() = default;
template<typename U>
ChannelStatus Send(U&& item);
ChannelStatus Receive(T* item);
ChannelStatus ReceiveMany(std::queue<T>* items);
void Close();
private:
std::queue<T> queue_;
std::mutex mutex_;
bool is_closed_;
std::condition_variable cond_;
};
template<typename T>
template<typename U>
ChannelStatus Channel<T>::Send(U&& item) {
bool notify;
{
std::unique_lock<std::mutex> lock(mutex_);
if (is_closed_) { return kChannelStatusErrorClosed; }
notify = queue_.empty();
queue_.push(std::forward<U>(item));
}
if (notify) { cond_.notify_one(); }
return kChannelStatusSuccess;
}
template<typename T>
ChannelStatus Channel<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
if (queue_.empty()) { return kChannelStatusErrorClosed; }
*item = std::move(queue_.front());
queue_.pop();
return kChannelStatusSuccess;
}
template<typename T>
ChannelStatus Channel<T>::ReceiveMany(std::queue<T>* items) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
if (queue_.empty()) { return kChannelStatusErrorClosed; }
while (!queue_.empty()) {
items->push(std::move(queue_.front()));
queue_.pop();
}
return kChannelStatusSuccess;
}
template<typename T>
void Channel<T>::Close() {
std::unique_lock<std::mutex> lock(mutex_);
is_closed_ = true;
cond_.notify_all();
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHANNEL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/range.h"
namespace oneflow {
void CallFromSenderThread(Channel<int>* channel, Range range) {
for (int i = range.begin(); i < range.end(); ++i) {
if (channel->Send(i) != kChannelStatusSuccess) { break; }
}
}
void CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {
int num = -1;
int* num_ptr = &num;
while (channel->Receive(num_ptr) == kChannelStatusSuccess) { ++visit->at(*num_ptr); }
}
TEST(Channel, 30sender40receiver) {
Channel<int> channel;
std::vector<std::thread> senders;
std::vector<std::thread> receivers;
int sender_num = 30;
int receiver_num = 40;
int range_num = 200;
std::vector<std::vector<int>> visits;
for (int i = 0; i < receiver_num; ++i) {
std::vector<int> visit_i;
for (int j = 0; j < range_num; j++) { visit_i.emplace_back(0); }
visits.emplace_back(visit_i);
}
for (int i = 0; i < sender_num; ++i) {
senders.emplace_back(CallFromSenderThread, &channel, Range(0, range_num));
}
for (int i = 0; i < receiver_num; ++i) {
receivers.emplace_back(CallFromReceiverThread, &visits[i], &channel);
}
for (std::thread& this_thread : senders) { this_thread.join(); }
channel.Close();
for (std::thread& this_thread : receivers) { this_thread.join(); }
for (int i = 0; i < range_num; ++i) {
int visit_count = 0;
for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; }
ASSERT_EQ(visit_count, sender_num);
}
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdlib>
#include <type_traits>
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/env_var/debug_mode.h"
namespace oneflow {
bool IsEnvEnabled(int32_t check_level) {
static const int env_check_level = ParseIntegerFromEnv("ONEFOW_CHECK_LEVEL", -1);
static const bool env_debug_mode = IsInDebugMode();
return env_debug_mode || env_check_level >= check_level;
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
#define ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
namespace oneflow {
bool IsEnvEnabled(int32_t check_level);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHECK_LEVEL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CONSTANT_H_
#define ONEFLOW_CORE_COMMON_CONSTANT_H_
#include <string>
namespace oneflow {
static const int64_t kInvalidSessionId = -1;
static const std::string kNoPassTag = "";
static const std::string kMainOp = "main_op";
static const int64_t kMaxSplitAxis = 6;
static const std::string kAsymmetricCodeErrorMsg =
"Maybe executing different code in different ranks, please check if the code is branched and "
"operates on the global tensor.";
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CONSTANT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
#define ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
#include <vector>
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/type_traits.h"
#include "oneflow/core/common/maybe.h"
namespace oneflow {
template<typename MapT, typename KeyT, typename U>
scalar_or_const_ref_t<typename MapT::mapped_type> MapAt(const MapT& map, const KeyT& key,
const U& default_val) {
const auto& iter = map.find(key);
if (iter == map.end()) { return default_val; }
return iter->second;
}
template<typename MapT, typename KeyT>
Maybe<scalar_or_const_ref_t<typename MapT::mapped_type>> MapAt(const MapT& map, const KeyT& key) {
const auto& iter = map.find(key);
CHECK_OR_RETURN(iter != map.end());
return iter->second;
}
template<typename MapT, typename KeyT>
Maybe<typename MapT::mapped_type&> MapAt(MapT& map, const KeyT& key) {
const auto& iter = map.find(key);
CHECK_OR_RETURN(iter != map.end());
return iter->second;
}
template<typename VecT>
Maybe<scalar_or_const_ref_t<typename VecT::value_type>> VectorAt(const VecT& vec,
typename VecT::size_type index) {
CHECK_LT_OR_RETURN(index, vec.size());
return vec[index];
}
template<typename VecT>
Maybe<typename VecT::value_type&> VectorAt(VecT& vec, typename VecT::size_type index) {
static_assert(!std::is_same<typename VecT::value_type, bool>::value,
"VectorAt(vector<bool>&, size_t) is not supported.");
CHECK_LT_OR_RETURN(index, vec.size());
return vec[index];
}
template<>
inline Maybe<bool> VectorAt(const std::vector<bool>& vec,
typename std::vector<bool>::size_type index) {
CHECK_LT_OR_RETURN(index, vec.size());
// convert vector bool proxy to bool
return static_cast<bool>(vec[index]);
}
template<typename T>
std::string Join(const T& con, const std::string& delimiter) {
std::ostringstream os;
auto b = begin(con), e = end(con);
if (b != e) {
std::copy(b, prev(e), std::ostream_iterator<typename T::value_type>(os, delimiter));
b = prev(e);
}
if (b != e) { os << *b; }
return os.str();
}
template<typename T>
using SmallSet = std::vector<T>;
template<typename T>
std::pair<typename SmallSet<T>::iterator, bool> SmallSetInsert(SmallSet<T>* vec, const T& elem) {
for (auto iter = vec->begin(); iter != vec->end(); ++iter) {
if (*iter == elem) { return std::make_pair(iter, false); }
}
vec->push_back(elem);
return std::make_pair(--vec->end(), true);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
namespace test {
TEST(VectorAt, write_int_vector) {
std::vector<int> vec = {1, 2, 3, 4, 5};
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 2);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 4);
CHECK_JUST(VectorAt(vec, 1)) = 6;
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)), 6);
CHECK_JUST(VectorAt(vec, 3)) = 8;
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 3)), 8);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)), 1);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 2)), 3);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 4)), 5);
}
namespace {
class A {
public:
explicit A(int a) : a(a) {}
int a;
};
} // namespace
TEST(VectorAt, write_custom_class_vector) {
std::vector<A> vec = {A(1), A(2)};
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 1);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 2);
CHECK_JUST(VectorAt(vec, 0)) = A(3);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 0)).a, 3);
CHECK_JUST(VectorAt(vec, 1)) = A(4);
EXPECT_EQ(CHECK_JUST(VectorAt(vec, 1)).a, 4);
}
} // namespace test
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#define ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#if __cplusplus < 201703L
#include <functional>
#include <numeric>
namespace std {
// a sequential version of inclusive_scan and exclusive_scan
template<class InputIt, class OutputIt>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first) {
return partial_sum(first, last, d_first);
}
template<class InputIt, class OutputIt, class BinaryOperation>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first, BinaryOperation binary_op) {
return partial_sum(first, last, d_first, binary_op);
}
template<class InputIt, class OutputIt, class BinaryOperation, class T>
OutputIt inclusive_scan(InputIt first, InputIt last, OutputIt d_first, BinaryOperation binary_op,
T init) {
// Based on https://en.cppreference.com/w/cpp/algorithm/partial_sum
if (first == last) return d_first;
typename std::iterator_traits<InputIt>::value_type sum = op(*first, init);
*d_first = sum;
while (++first != last) {
sum = binary_op(sum, *first);
*++d_first = sum;
}
return ++d_first;
}
template<class InputIt, class OutputIt, class T, class BinaryOperation>
OutputIt exclusive_scan(InputIt first, InputIt last, OutputIt d_first, T init,
BinaryOperation binary_op) {
if (first == last) return d_first;
typename std::iterator_traits<InputIt>::value_type sum = init;
*d_first = sum;
first--;
last--;
while (++first != last) {
sum = binary_op(sum, *first);
*++d_first = sum;
}
return ++d_first;
}
template<class InputIt, class OutputIt, class T>
OutputIt exclusive_scan(InputIt first, InputIt last, OutputIt d_first, T init) {
return exclusive_scan(first, last, d_first, init, std::plus<>());
}
} // namespace std
#endif
#endif // ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <gtest/gtest.h>
#include <functional>
#include <iostream>
#include <iterator>
#include <vector>
#include "oneflow/core/common/cplusplus_17.h"
namespace oneflow {
namespace test {
TEST(Scan, scan) {
std::vector<int> data{3, 1, 4, 1, 5, 9, 2, 6};
std::vector<int> output;
std::exclusive_scan(data.begin(), data.end(), std::back_insert_iterator<std::vector<int>>(output),
0);
std::vector<int> ref_output = {0, 3, 4, 8, 9, 14, 23, 25};
EXPECT_EQ(output, ref_output);
output.clear();
std::inclusive_scan(data.begin(), data.end(),
std::back_insert_iterator<std::vector<int>>(output));
ref_output = {3, 4, 8, 9, 14, 23, 25, 31};
EXPECT_EQ(output, ref_output);
output.clear();
std::exclusive_scan(data.begin(), data.end(), std::back_insert_iterator<std::vector<int>>(output),
1, std::multiplies<>{});
ref_output = {1, 3, 3, 12, 12, 60, 540, 1080};
EXPECT_EQ(output, ref_output);
output.clear();
std::inclusive_scan(data.begin(), data.end(), std::back_insert_iterator<std::vector<int>>(output),
std::multiplies<>{});
ref_output = {3, 3, 12, 12, 60, 540, 1080, 6480};
EXPECT_EQ(output, ref_output);
output.clear();
std::exclusive_scan(data.rbegin(), data.rend(),
std::back_insert_iterator<std::vector<int>>(output), 1, std::multiplies<>{});
ref_output = {1, 6, 12, 108, 540, 540, 2160, 2160};
EXPECT_EQ(output, ref_output);
output.clear();
}
} // namespace test
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
#define ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
#include <glog/logging.h>
#define likely GOOGLE_PREDICT_TRUE
#define unlikely GOOGLE_PREDICT_FALSE
#endif // ONEFLOW_CORE_COMMON_CPP_ATTRIBUTE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/tensor_buffer.h"
namespace oneflow {
bool IsBoolDataType(DataType data_type) {
switch (data_type) {
#define BOOL_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(BOOL_CASE, BOOL_DATA_TYPE_SEQ)
default: return false;
}
#undef BOOL_CASE
}
bool IsIntegralDataType(DataType data_type) {
switch (data_type) {
#define INTEGRAL_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(INTEGRAL_CASE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)
default: return false;
}
#undef INTEGRAL_CASE
}
bool IsFloatingDataType(DataType data_type) {
switch (data_type) {
#define FLOATING_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(FLOATING_CASE, FLOATING_DATA_TYPE_SEQ)
default: return false;
}
#undef FLOATING_CASE
}
bool IsPODDataType(DataType data_type) {
switch (data_type) {
#define POD_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(POD_CASE, POD_DATA_TYPE_SEQ)
default: return false;
}
#undef POD_CASE
}
bool IsPODAndHalfDataType(DataType data_type) {
switch (data_type) {
#define POD_AND_HALF_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(POD_AND_HALF_CASE, POD_AND_HALF_DATA_TYPE_SEQ)
default: return false;
}
#undef POD_AND_HALF_CASE
}
bool IsIndexDataType(DataType data_type) {
switch (data_type) {
#define INDEX_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(INDEX_CASE, INDEX_DATA_TYPE_SEQ)
default: return false;
}
#undef INDEX_CASE
}
bool IsSupportRequireGradDataType(DataType data_type) {
switch (data_type) {
#define REQUIRE_GRAD_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(REQUIRE_GRAD_CASE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)
default: return false;
}
#undef REQUIRE_GRAD_CASE
}
bool NotSupportBoxingDataType(DataType data_type) {
switch (data_type) {
#define NO_BOXING_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(NO_BOXING_CASE, NO_BOXING_DATA_TYPE_SEQ)
default: return false;
}
#undef NO_BOXING_CASE
}
size_t GetSizeOfDataType(DataType data_type) {
switch (data_type) {
// 8-bit
case kChar: return 1;
case kInt8: return 1;
case kUInt8: return 1;
case kBool: return 1;
// 16-bit
case kInt16: return 2;
case kUInt16: return 2;
case kFloat16: return 2;
case kBFloat16: return 2;
// 32-bit
case kInt32: return 4;
case kUInt32: return 4;
case kFloat: return 4;
case kComplex32: return 4;
// 64-bit
case kInt64: return 8;
case kUInt64: return 8;
case kDouble: return 8;
case kComplex64: return 8;
// 128-bit
case kInt128: return 16;
case kUInt128: return 16;
case kComplex128: return 16;
// non pod
case kOFRecord: return sizeof(OFRecord);
case kTensorBuffer: return sizeof(TensorBuffer);
default: LOG(FATAL) << "invalid data_type: " << DataType_Name(data_type);
}
}
namespace {
void CheckDataType() {
static_assert(sizeof(int8_t) == sizeof(char), "sizeof(int8_t) != sizeof(char)");
static_assert(sizeof(int16_t) == sizeof(short), "sizeof(int16_t) != sizeof(short)");
static_assert(sizeof(int32_t) == sizeof(int), "sizeof(int32_t) != sizeof(int)");
static_assert(sizeof(int64_t) == sizeof(long long), "sizeof(int64_t) != sizeof(long long)");
#if defined(WITH_CUDA)
#define CHECK_DEVICE_FP16(get_val) \
do { \
float16 host_fp16 = get_val<float16>(); \
half device_fp16 = get_val<half>(); \
CHECK_EQ(*(uint16_t*)&host_fp16, *(uint16_t*)&device_fp16); \
} while (0)
CHECK_DEVICE_FP16(GetZeroVal);
CHECK_DEVICE_FP16(GetOneVal);
CHECK_DEVICE_FP16(GetMaxVal);
CHECK_DEVICE_FP16(GetMinVal);
#undef CHECK_DEVICE_FP16
#endif
#if defined(WITH_ROCM)
#define CHECK_DEVICE_FP16(get_val) \
do { \
float16 host_fp16 = get_val<float16>(); \
half device_fp16 = get_val<half>(); \
CHECK_EQ(*(uint16_t*)&host_fp16, *(uint16_t*)&device_fp16); \
} while (0)
CHECK_DEVICE_FP16(GetZeroVal);
CHECK_DEVICE_FP16(GetOneVal);
CHECK_DEVICE_FP16(GetMaxVal);
CHECK_DEVICE_FP16(GetMinVal);
#undef CHECK_DEVICE_FP16
#endif
#define CHECK_MAX_VAL(T, limit_value) CHECK_EQ(GetMaxVal<T>(), std::numeric_limits<T>::max());
OF_PP_FOR_EACH_TUPLE(CHECK_MAX_VAL, MAX_VAL_SEQ);
#undef CHECK_MAX_VAL
#define CHECK_MIN_VAL(T, limit_value) CHECK_EQ(GetMinVal<T>(), std::numeric_limits<T>::lowest());
OF_PP_FOR_EACH_TUPLE(CHECK_MIN_VAL, MIN_VAL_SEQ);
#undef CHECK_MIN_VAL
}
COMMAND(CheckDataType());
} // namespace
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_H_
#include <cfloat>
#include <type_traits>
#if defined(WITH_CUDA)
#include <cuda_fp16.h>
#endif
#if defined(WITH_ROCM)
#include <hip/hip_fp16.h>
#endif
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/record/record.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/device_type.h"
#include <half.hpp>
namespace oneflow {
typedef half_float::half float16;
template<typename T>
struct IsFloat16;
template<>
struct IsFloat16<float16> : std::true_type {};
#ifdef WITH_CUDA
template<>
struct IsFloat16<half> : std::true_type {};
#endif // WITH_CUDA
#ifdef WITH_ROCM
template<>
struct IsFloat16<half> : std::true_type {};
#endif // WITH_ROCM
template<typename T>
struct IsFloat16 : std::false_type {};
// Type Trait: IsFloating
template<typename T>
struct IsFloating : std::integral_constant<bool, false> {};
#define SPECIALIZE_TRUE_FLOATING(type_cpp, type_proto) \
template<> \
struct IsFloating<type_cpp> : std::integral_constant<bool, true> {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_FLOATING, FLOATING_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_FLOATING
// Type Trait: IsIntegral
template<typename T>
struct IsIntegral : std::integral_constant<bool, false> {};
#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \
template<> \
struct IsIntegral<type_cpp> : std::integral_constant<bool, true> {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: IsUnsignedIntegral
template<typename T>
struct IsUnsignedIntegral : std::integral_constant<bool, false> {};
#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \
template<> \
struct IsUnsignedIntegral<type_cpp> : std::integral_constant<bool, true> {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, UNSIGNED_INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: GetDataType
template<typename T, typename T2 = void>
struct GetDataType;
template<>
struct GetDataType<void> : std::integral_constant<DataType, DataType::kChar> {};
#define SPECIALIZE_GET_DATA_TYPE(type_cpp, type_proto) \
template<> \
struct GetDataType<type_cpp> : std::integral_constant<DataType, type_proto> {}; \
inline type_cpp GetTypeByDataType(std::integral_constant<DataType, type_proto>) { return {}; }
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE, ALL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ);
#undef SPECIALIZE_GET_DATA_TYPE
template<typename T>
struct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type>
: std::integral_constant<DataType, DataType::kFloat16> {};
template<DataType type>
using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{}));
#if defined(__CUDACC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#elif defined(__HIPCC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#else
#define OF_DEVICE_FUNC inline
#endif
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetZeroVal() {
return static_cast<T>(0);
}
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetOneVal() {
return static_cast<T>(1);
}
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMinVal();
template<typename T, typename std::enable_if<!IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMaxVal();
#ifdef __APPLE__
#define APPLE_MAX_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, ULONG_MAX)
#else
#define APPLE_MAX_VAL_SEQ
#endif
#define MAX_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MAX) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, UINT8_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint16_t, UINT16_MAX) \
OF_PP_MAKE_TUPLE_SEQ(uint32_t, UINT32_MAX) \
APPLE_MAX_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint64_t, UINT64_MAX) \
OF_PP_MAKE_TUPLE_SEQ(float, FLT_MAX) \
OF_PP_MAKE_TUPLE_SEQ(double, DBL_MAX) \
OF_PP_MAKE_TUPLE_SEQ(bool, true)
#ifdef __APPLE__
#define APPLE_MIN_VAL_SEQ OF_PP_MAKE_TUPLE_SEQ(unsigned long, 0)
#else
#define APPLE_MIN_VAL_SEQ
#endif
#define MIN_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(int8_t, INT8_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int16_t, INT16_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int32_t, INT32_MIN) \
OF_PP_MAKE_TUPLE_SEQ(int64_t, INT64_MIN) \
OF_PP_MAKE_TUPLE_SEQ(uint8_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(uint16_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(uint32_t, 0) \
APPLE_MIN_VAL_SEQ \
OF_PP_MAKE_TUPLE_SEQ(uint64_t, 0) \
OF_PP_MAKE_TUPLE_SEQ(float, -FLT_MAX) \
OF_PP_MAKE_TUPLE_SEQ(double, -DBL_MAX) \
OF_PP_MAKE_TUPLE_SEQ(bool, false)
#define SPECIALIZE_MAX_VAL(T, limit_value) \
template<> \
OF_DEVICE_FUNC T GetMaxVal<T>() { \
return limit_value; \
}
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_MAX_VAL, MAX_VAL_SEQ);
#undef SPECIALIZE_MAX_VAL
#define SPECIALIZE_MIN_VAL(T, limit_value) \
template<> \
OF_DEVICE_FUNC T GetMinVal<T>() { \
return limit_value; \
}
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_MIN_VAL, MIN_VAL_SEQ);
#undef SPECIALIZE_MIN_VAL
template<typename T>
const T* GetZeroPtr() {
static const T ret = GetZeroVal<T>();
return &ret;
}
template<typename T>
const T* GetOnePtr() {
static const T ret = GetOneVal<T>();
return &ret;
}
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetZeroVal() {
uint16_t ret = 0x0; // Decimal: 0; Binary: 0 00000 0000000000
return *(T*)&ret;
}
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetOneVal() {
uint16_t ret = 0x3c00; // Decimal: 15360; Binary: 0 01111 0000000000
return *(T*)&ret;
}
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMaxVal() {
uint16_t ret = 0x7bff; // Decimal: 31743; Binary: 0 11110 1111111111
return *(T*)&ret;
}
template<typename T, typename std::enable_if<IsFloat16<T>::value>::type* = nullptr>
OF_DEVICE_FUNC T GetMinVal() {
uint16_t ret = 0xfbff; // Decimal: 64511; Binary: 1 11110 1111111111
return *(T*)&ret;
}
template<DeviceType, typename T>
struct DevDType {
typedef T type;
};
#if defined(WITH_CUDA)
template<>
struct DevDType<DeviceType::kCUDA, float16> {
static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
typedef half type;
};
#endif
#if defined(WITH_ROCM)
template<>
struct DevDType<DeviceType::kCUDA, float16> {
static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
typedef half type;
};
#endif
// Func
bool IsBoolDataType(DataType data_type);
bool IsIntegralDataType(DataType data_type);
bool IsFloatingDataType(DataType data_type);
bool IsSupportRequireGradDataType(DataType data_type);
bool IsPODDataType(DataType data_type);
bool IsPODAndHalfDataType(DataType data_type);
bool IsIndexDataType(DataType data_type);
bool NotSupportBoxingDataType(DataType data_type);
size_t GetSizeOfDataType(DataType data_type);
inline bool operator==(const OptInt64& lhs, const OptInt64& rhs) {
return (lhs.has_value() && rhs.has_value() && lhs.value() == rhs.value())
|| (!lhs.has_value() && !rhs.has_value());
}
template<typename T>
void CheckDataType(DataType data_type) {
LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false && std::is_same<T, long>::value == false
&& data_type != DataType::kChar && data_type != GetDataType<T>::value))
<< data_type << " " << GetDataType<T>::value;
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_H_
syntax = "proto2";
package oneflow;
enum DataType {
kInvalidDataType = 0;
kChar = 1;
kFloat = 2;
kDouble = 3;
kInt8 = 4;
kInt32 = 5;
kInt64 = 6;
kUInt8 = 7;
kOFRecord = 8;
kFloat16 = 9;
kTensorBuffer = 10;
kBFloat16 = 11;
kBool = 12;
kUInt16 = 13;
kUInt32 = 14;
kUInt64 = 15;
kUInt128 = 16;
kInt16 = 17;
kInt128 = 18;
kComplex32 = 19;
kComplex64 = 20;
kComplex128 = 21;
}
message OptInt64 {
optional int64 value = 1 [ default = -1 ];
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
#ifdef WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#endif
#include <cstdint>
#include <limits>
#include <type_traits>
#include "oneflow/core/common/data_type.h"
namespace oneflow {
template<typename T>
struct IsFloatingOrHalf {
static const bool value = IsFloating<T>::value || IsFloat16<T>::value;
};
template<typename T>
struct IsArithmeticOrHalf {
static const bool value = std::is_arithmetic<T>::value || IsFloat16<T>::value;
};
template<typename From, typename To>
struct NeedsClamp {
static const bool from_fp = IsFloatingOrHalf<From>::value;
static const bool to_fp = IsFloatingOrHalf<To>::value;
static const bool from_fp16 = IsFloat16<From>::value;
static const bool to_fp16 = IsFloat16<To>::value;
static const bool from_unsigned = std::is_unsigned<From>::value;
static const bool to_unsigned = std::is_unsigned<To>::value;
static const bool value =
// to smaller type of same kind (fp, int)
(from_fp == to_fp && sizeof(To) < sizeof(From)) ||
// fp32 has range in excess of (u)int64
(from_fp && !to_fp) ||
// converting to unsigned requires clamping negatives to zero
(!from_unsigned && to_unsigned) ||
// zero-extending signed unsigned integers requires more bits
(from_unsigned && !to_unsigned && sizeof(To) <= sizeof(From)) ||
// float16
(to_fp16 && sizeof(To) <= sizeof(From));
};
template<typename To>
struct NeedsClamp<bool, To> {
static const bool value = false;
};
template<typename T, typename U, typename Enabled = void>
struct ClampHelper {};
// floating-point and signed integer -> floating-point and signed integer
template<typename T, typename U>
struct ClampHelper<
T, U,
std::enable_if_t<
NeedsClamp<U, T>::value && std::is_signed<U>::value && std::is_signed<T>::value, void>> {
OF_DEVICE_FUNC static const T Call(U value) {
return value <= GetMinVal<T>() ? GetMinVal<T>()
: value >= GetMaxVal<T>() ? GetMaxVal<T>()
: static_cast<T>(value);
}
};
// floating-point -> unsigned types
template<typename T, typename U>
struct ClampHelper<T, U,
std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value
&& IsFloatingOrHalf<U>::value && std::is_unsigned<T>::value,
void>> {
OF_DEVICE_FUNC static const T Call(U value) {
return value <= GetMinVal<T>() ? GetMinVal<T>()
: value >= GetMaxVal<T>() ? GetMaxVal<T>()
: static_cast<T>(value);
}
};
// signed integer types -> unsigned types
template<typename T, typename U>
struct ClampHelper<T, U,
std::enable_if_t<NeedsClamp<U, T>::value && std::is_signed<U>::value
&& std::is_integral<U>::value && std::is_unsigned<T>::value,
void>> {
OF_DEVICE_FUNC static const T Call(U value) {
return value <= 0 ? 0
: static_cast<std::make_unsigned_t<U>>(value) >= GetMaxVal<T>() ? GetMaxVal<T>()
: static_cast<T>(value);
}
};
// unsigned types -> any types
template<typename T, typename U>
struct ClampHelper<T, U,
std::enable_if_t<NeedsClamp<U, T>::value && std::is_unsigned<U>::value, void>> {
OF_DEVICE_FUNC static const T Call(U value) {
return value >= GetMaxVal<T>() ? GetMaxVal<T>() : static_cast<T>(value);
}
};
// not clamp
template<typename T, typename U>
struct ClampHelper<T, U, std::enable_if_t<!NeedsClamp<U, T>::value, void>> {
OF_DEVICE_FUNC static const T Call(U value) { return value; }
};
OF_DEVICE_FUNC const int32_t Clamp(uint32_t value) {
return value & 0x80000000u ? 0x7fffffff : value;
}
OF_DEVICE_FUNC const uint32_t Clamp(int32_t value) { return value < 0 ? 0u : value; }
OF_DEVICE_FUNC const int32_t Clamp(int64_t value) {
return value < static_cast<int64_t>(GetMinVal<int32_t>()) ? GetMinVal<int32_t>()
: value > static_cast<int64_t>(GetMaxVal<int32_t>()) ? GetMaxVal<int32_t>()
: static_cast<int32_t>(value);
}
template<>
struct ClampHelper<int32_t, uint64_t> {
OF_DEVICE_FUNC static const int32_t Call(uint64_t value) {
return value > static_cast<uint64_t>(GetMaxVal<int32_t>()) ? GetMaxVal<int32_t>()
: static_cast<int32_t>(value);
}
};
template<>
struct ClampHelper<uint32_t, int64_t> {
OF_DEVICE_FUNC static const uint32_t Call(int64_t value) {
return value < 0 ? 0
: value > static_cast<int64_t>(GetMaxVal<uint32_t>()) ? GetMaxVal<uint32_t>()
: static_cast<uint32_t>(value);
}
};
template<>
struct ClampHelper<uint32_t, uint64_t> {
OF_DEVICE_FUNC static const uint32_t Call(uint64_t value) {
return value > static_cast<uint64_t>(GetMaxVal<uint32_t>()) ? GetMaxVal<uint32_t>()
: static_cast<uint32_t>(value);
}
};
template<typename T>
struct ClampHelper<bool, T> {
OF_DEVICE_FUNC static const bool Call(T value) { return static_cast<bool>(value); }
};
template<typename T>
struct ClampHelper<float16, T> {
inline static const float16 Call(T value) {
return static_cast<float16>(ClampHelper<T, float>::Call(value) < GetMinVal<float16>()
? GetMinVal<float16>()
: ClampHelper<T, float>::Call(value) > GetMaxVal<float16>()
? GetMaxVal<float16>()
: ClampHelper<T, float>::Call(value));
}
};
template<typename T>
struct ClampHelper<T, float16> {
inline static const T Call(float16 value) {
return ClampHelper<T, float>::Call(static_cast<float>(value));
}
};
inline const float16 Clamp(float16 value) { return value; }
template<typename T, typename U>
OF_DEVICE_FUNC const T Clamp(U value) {
return ClampHelper<T, U>::Call(value);
}
namespace {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
inline __device__ int cuda_round_helper(float f, int) { return __float2int_rn(f); }
inline __device__ unsigned cuda_round_helper(float f, unsigned) { return __float2uint_rn(f); }
inline __device__ long long cuda_round_helper(float f, long long) {
return __float2ll_rd(f + 0.5f);
}
inline __device__ unsigned long long cuda_round_helper(float f, unsigned long long) {
return __float2ull_rd(f + 0.5f);
}
inline __device__ long cuda_round_helper(float f, long) {
return sizeof(long) == sizeof(int) ? __float2int_rn(f) : __float2ll_rd(f + 0.5f);
}
inline __device__ unsigned long cuda_round_helper(float f, unsigned long) {
return sizeof(unsigned long) == sizeof(unsigned int) ? __float2uint_rn(f)
: __float2ull_rd(f + 0.5f);
}
inline __device__ int cuda_round_helper(double f, int) { return __double2int_rn(f); }
inline __device__ unsigned cuda_round_helper(double f, unsigned) { return __double2uint_rn(f); }
inline __device__ long long cuda_round_helper(double f, long long) {
return __double2ll_rd(f + 0.5f);
}
inline __device__ unsigned long long cuda_round_helper(double f, unsigned long long) {
return __double2ull_rd(f + 0.5f);
}
inline __device__ long cuda_round_helper(double f, long) {
return sizeof(long) == sizeof(int) ? __double2int_rn(f) : __double2ll_rd(f + 0.5f);
}
inline __device__ unsigned long cuda_round_helper(double f, unsigned long) {
return sizeof(unsigned long) == sizeof(unsigned int) ? __double2uint_rn(f)
: __double2ull_rd(f + 0.5f);
}
#endif
template<typename Out, typename In, bool OutIsFp = IsFloatingOrHalf<Out>::value,
bool InIsFp = IsFloatingOrHalf<In>::value>
struct ConverterBase;
template<typename Out, typename In>
struct Converter : ConverterBase<Out, In> {
static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,
"Default ConverterBase can only be used with arithmetic types.");
};
// Converts between two FP types
template<typename Out, typename In>
struct ConverterBase<Out, In, true, true> {
OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertNorm(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return value; }
};
// Converts integral to FP type
template<typename Out, typename In>
struct ConverterBase<Out, In, true, false> {
OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertSat(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
return value * (Out(1) / (GetMaxVal<In>()));
}
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
return value * (Out(1) / (GetMaxVal<In>()));
}
};
// Converts integral to float16
template<typename In>
struct ConverterBase<float16, In, true, false> {
OF_DEVICE_FUNC static const float16 Convert(In value) {
auto out = ConverterBase<float, In, true, false>::Convert(value);
return static_cast<float16>(out);
}
OF_DEVICE_FUNC static const float16 ConvertSat(In value) {
auto out = ConverterBase<float, In, true, false>::ConvertSat(value);
return static_cast<float16>(out);
}
OF_DEVICE_FUNC static const float16 ConvertNorm(In value) {
auto out = ConverterBase<float, In, true, false>::ConvertNorm(value);
return static_cast<float16>(out);
}
OF_DEVICE_FUNC static const float16 ConvertSatNorm(In value) {
auto out = ConverterBase<float, In, true, false>::ConvertSatNorm(value);
return static_cast<float16>(out);
}
};
// Converts FP to integral type
template<typename Out, typename In>
struct ConverterBase<Out, In, false, true> {
OF_DEVICE_FUNC static const Out Convert(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return Clamp<Out>(cuda_round_helper(value, Out()));
#else
return Clamp<Out>(std::round(value));
#endif
}
OF_DEVICE_FUNC static const Out ConvertSat(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return Clamp<Out>(cuda_round_helper(value, Out()));
#else
return Clamp<Out>(std::round(value));
#endif
}
OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()));
#else
return std::round(value * GetMaxVal<Out>());
#endif
}
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return std::is_signed<Out>::value
? Clamp<Out>(cuda_round_helper(value * GetMaxVal<Out>(), Out()))
: cuda_round_helper(GetMaxVal<Out>() * __saturatef(value), Out());
#else
return Clamp<Out>(std::round(value * GetMaxVal<Out>()));
#endif
}
};
// Converts signed to signed, unsigned to unsigned or unsigned to signed
template<typename Out, typename In, bool IsOutSigned = std::is_signed<Out>::value,
bool IsInSigned = std::is_signed<In>::value>
struct ConvertIntInt {
OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
return Converter<Out, float>::Convert(value * (1.0f * GetMaxVal<Out>() / GetMaxVal<In>()));
}
OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { return ConvertNorm(value); }
};
// Converts signed to unsigned integer
template<typename Out, typename In>
struct ConvertIntInt<Out, In, false, true> {
OF_DEVICE_FUNC static const Out Convert(In value) { return value; }
OF_DEVICE_FUNC static const Out ConvertNorm(In value) {
return Converter<Out, float>::Convert(value * (1.0f * GetMaxVal<Out>() / GetMaxVal<In>()));
}
OF_DEVICE_FUNC static const Out ConvertSat(In value) { return Clamp<Out>(value); }
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal<In>())) * GetMaxVal<Out>());
}
#else
return value < 0 ? 0 : ConvertNorm(value);
}
#endif
};
// Converts between integral types
template<typename Out, typename In>
struct ConverterBase<Out, In, false, false> : ConvertIntInt<Out, In> {
static_assert(IsArithmeticOrHalf<Out>::value && IsArithmeticOrHalf<In>::value,
"Default ConverterBase can only be used with arithmetic types.");
};
// Pass-through conversion
template<typename T>
struct Converter<T, T> {
static OF_DEVICE_FUNC const T Convert(T value) { return value; }
static OF_DEVICE_FUNC const T ConvertSat(T value) { return value; }
static OF_DEVICE_FUNC const T ConvertNorm(T value) { return value; }
static OF_DEVICE_FUNC const T ConvertSatNorm(T value) { return value; }
};
template<typename raw_out, typename raw_in>
using converter_t =
Converter<std::remove_cv_t<raw_out>, std::remove_cv_t<std::remove_reference_t<raw_in>>>;
} // namespace
template<typename Out, typename In>
OF_DEVICE_FUNC const Out Convert(In value) {
return converter_t<Out, In>::Convert(value);
}
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertNorm(In value) {
return converter_t<Out, In>::ConvertNorm(value);
}
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertSat(In value) {
return converter_t<Out, In>::ConvertSat(value);
}
template<typename Out, typename In>
OF_DEVICE_FUNC const Out ConvertSatNorm(In value) {
return converter_t<Out, In>::ConvertSatNorm(value);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "util.h"
#include "oneflow/core/common/data_type_converter.h"
#include "oneflow/core/common/data_type_converter_test_static.h"
#ifdef __CUDA_ARCH__
#include <cuda_runtime.h>
#elif defined(__HIP_DEVICE_COMPILE__)
#include <hip/hip_runtime.h>
#else
#include <cmath>
#endif
namespace oneflow {
namespace {
// cpp17 std::clamp possible implementation
template<class T>
constexpr const T& clamp(const T& v, const T& lo, const T& hi) {
return (v < lo) ? lo : (hi < v) ? hi : v;
}
} // namespace
TEST(ClampTest, Clamp) {
ASSERT_TRUE(Clamp<uint8_t>(0) == 0);
ASSERT_TRUE(Clamp<uint8_t>(255) == 255);
ASSERT_TRUE(Clamp<uint8_t>(100) == 100);
ASSERT_TRUE(Clamp<uint8_t>(100.3) == 100);
ASSERT_TRUE(Clamp<uint8_t>(256) == 255);
ASSERT_TRUE(Clamp<uint8_t>(-4) == 0);
ASSERT_TRUE(Clamp<uint8_t>(-4.0f) == 0);
ASSERT_TRUE(Clamp<uint8_t>(1e+20f) == 255);
ASSERT_TRUE(Clamp<uint8_t>(-1e+20f) == 0);
ASSERT_TRUE(Clamp<uint8_t>(1e+200) == 255);
ASSERT_TRUE(Clamp<uint8_t>(-1e+200) == 0);
ASSERT_TRUE(Clamp<int8_t>(-4) == -4);
ASSERT_TRUE(Clamp<int8_t>(-4.2) == -4);
ASSERT_TRUE(Clamp<int8_t>(4.2) == 4);
ASSERT_TRUE(Clamp<int8_t>(127) == 127);
ASSERT_TRUE(Clamp<int8_t>(128) == 127);
ASSERT_TRUE(Clamp<int8_t>(256) == 127);
ASSERT_TRUE(Clamp<int8_t>(-128) == -128);
ASSERT_TRUE(Clamp<int8_t>(-256) == -128);
ASSERT_TRUE(Clamp<int8_t>(1e+20f) == 127);
ASSERT_TRUE(Clamp<int8_t>(-1e+20f) == -128);
ASSERT_TRUE(Clamp<int8_t>(1e+200) == 127);
ASSERT_TRUE(Clamp<int8_t>(-1e+200) == -128);
ASSERT_TRUE(Clamp<uint16_t>(0) == 0);
ASSERT_TRUE(Clamp<uint16_t>(0xffff) == 0xffff);
ASSERT_TRUE(Clamp<uint16_t>(100) == 100);
ASSERT_TRUE(Clamp<uint16_t>(100.3) == 100);
ASSERT_TRUE(Clamp<uint16_t>(0x10000) == 0xffff);
ASSERT_TRUE(Clamp<uint16_t>(-4) == 0);
ASSERT_TRUE(Clamp<uint16_t>(-4.0f) == 0);
ASSERT_TRUE(Clamp<uint16_t>(1e+20f) == 0xffff);
ASSERT_TRUE(Clamp<uint16_t>(-1e+20f) == 0);
ASSERT_TRUE(Clamp<uint16_t>(1e+200) == 0xffff);
ASSERT_TRUE(Clamp<uint16_t>(-1e+200) == 0);
ASSERT_TRUE(Clamp<int16_t>(-4) == -4);
ASSERT_TRUE(Clamp<int16_t>(-4.2) == -4);
ASSERT_TRUE(Clamp<int16_t>(4.2) == 4);
ASSERT_TRUE(Clamp<int16_t>(0x7fff) == 0x7fff);
ASSERT_TRUE(Clamp<int16_t>(0x8000) == 0x7fff);
ASSERT_TRUE(Clamp<int16_t>(0x10000) == 0x7fff);
ASSERT_TRUE(Clamp<int16_t>(-0x8000) == -0x8000);
ASSERT_TRUE(Clamp<int16_t>(-0x10000) == -0x8000);
ASSERT_TRUE(Clamp<int16_t>(1e+20f) == 0x7fff);
ASSERT_TRUE(Clamp<int16_t>(-1e+20f) == -0x8000);
ASSERT_TRUE(Clamp<int16_t>(1e+200) == 0x7fff);
ASSERT_TRUE(Clamp<int16_t>(-1e+200) == -0x8000);
ASSERT_TRUE(Clamp<uint32_t>(0) == 0);
ASSERT_TRUE(Clamp<uint32_t>(0xffffffffLL) == 0xffffffffLL);
ASSERT_TRUE(Clamp<uint32_t>(100) == 100);
ASSERT_TRUE(Clamp<uint32_t>(100.3) == 100);
ASSERT_TRUE(Clamp<uint32_t>(0x100000000LL) == 0xffffffffLL);
ASSERT_TRUE(Clamp<uint32_t>(-4) == 0);
ASSERT_TRUE(Clamp<uint32_t>(-4.0f) == 0);
ASSERT_TRUE(Clamp<uint32_t>(1e+20f) == 0xffffffffu);
ASSERT_TRUE(Clamp<uint32_t>(-1.0e+20f) == 0);
ASSERT_TRUE(Clamp<uint32_t>(1e+200) == 0xffffffffu);
ASSERT_TRUE(Clamp<uint32_t>(-1.0e+200) == 0);
ASSERT_TRUE(Clamp<int32_t>(-4) == -4);
ASSERT_TRUE(Clamp<int32_t>(-4LL) == -4);
ASSERT_TRUE(Clamp<int32_t>(-4.2) == -4);
ASSERT_TRUE(Clamp<int32_t>(4.2) == 4);
ASSERT_TRUE(Clamp<int32_t>(0x7fffffff) == 0x7fffffff);
ASSERT_TRUE(Clamp<int32_t>(0x80000000L) == 0x7fffffff);
ASSERT_TRUE(Clamp<int32_t>(0x100000000L) == 0x7fffffff);
ASSERT_TRUE(Clamp<int32_t>(-0x80000000LL) == -0x7fffffff - 1);
ASSERT_TRUE(Clamp<int32_t>(-0x100000000LL) == -0x7fffffff - 1);
ASSERT_TRUE(Clamp<int32_t>(1.0e+20f) == 0x7fffffff);
ASSERT_TRUE(Clamp<int32_t>(-1.0e+20f) == -0x80000000L);
ASSERT_TRUE(Clamp<int32_t>(1.0e+200) == 0x7fffffff);
ASSERT_TRUE(Clamp<int32_t>(-1.0e+200) == -0x80000000L);
ASSERT_TRUE(Clamp<int64_t>(1.0e+200) == 0x7fffffffffffffffLL);
ASSERT_TRUE(Clamp<int64_t>(-1.0e+200) == -0x7fffffffffffffffLL - 1);
ASSERT_TRUE(Clamp<uint64_t>(1.0e+200) == 0xffffffffffffffffULL);
ASSERT_TRUE(Clamp<uint64_t>(-1.0e+200) == 0);
}
TEST(ConvertSat, float2int) {
FOR_RANGE(int32_t, exp, -10, 100) {
FOR_RANGE(float, sig, -256, 257) {
float f = ldexpf(sig, exp);
float integral;
float fract = modff(f, &integral);
if (fract == 0.5f || fract == -0.5f) continue;
double rounded = roundf(f);
int64_t clamped = clamp<double>(rounded, -128, 127);
ASSERT_EQ(ConvertSat<int8_t>(f), clamped) << " with f = " << f;
clamped = clamp<double>(rounded, 0, 255);
ASSERT_EQ(ConvertSat<uint8_t>(f), clamped) << " with f = " << f;
clamped = clamp<double>(rounded, -0x8000, 0x7fff);
ASSERT_EQ(ConvertSat<int16_t>(f), clamped) << " with f = " << f;
clamped = clamp<double>(rounded, 0, 0xffff);
ASSERT_EQ(ConvertSat<uint16_t>(f), clamped) << " with f = " << f;
clamped = clamp<double>(rounded, int32_t(~0x7fffffff), 0x7fffffff);
ASSERT_EQ(ConvertSat<int32_t>(f), clamped) << " with f = " << f;
clamped = clamp<double>(rounded, 0, 0xffffffffu);
ASSERT_EQ(ConvertSat<uint32_t>(f), clamped) << " with f = " << f;
}
}
}
TEST(ConvertNorm, int2int) {
EXPECT_EQ((ConvertNorm<uint8_t, uint8_t>(0)), 0);
EXPECT_EQ((ConvertNorm<uint8_t, int8_t>(127)), 255);
}
TEST(ConvertNorm, float2int) {
EXPECT_EQ(ConvertNorm<uint8_t>(0.0f), 0);
EXPECT_EQ(ConvertNorm<uint8_t>(0.499f), 127);
EXPECT_EQ(ConvertNorm<uint8_t>(1.0f), 255);
EXPECT_EQ(ConvertNorm<int8_t>(1.0f), 127);
EXPECT_EQ(ConvertNorm<int8_t>(0.499f), 63);
EXPECT_EQ(ConvertNorm<int8_t>(-1.0f), -127);
EXPECT_EQ(ConvertNorm<uint16_t>(0.0f), 0);
EXPECT_EQ(ConvertNorm<uint16_t>(1.0f), 0xffff);
EXPECT_EQ(ConvertNorm<int16_t>(1.0f), 0x7fff);
EXPECT_EQ(ConvertNorm<int16_t>(-1.0f), -0x7fff);
}
TEST(ConvertSatNorm, float2int) {
EXPECT_EQ(ConvertSatNorm<uint8_t>(2.0f), 255);
EXPECT_EQ(ConvertSatNorm<uint8_t>(0.499f), 127);
EXPECT_EQ(ConvertSatNorm<uint8_t>(-2.0f), 0);
EXPECT_EQ(ConvertSatNorm<int8_t>(2.0f), 127);
EXPECT_EQ(ConvertSatNorm<int8_t>(0.499f), 63);
EXPECT_EQ(ConvertSatNorm<int8_t>(-2.0f), -128);
EXPECT_EQ(ConvertSatNorm<uint8_t>(0.4f / 255), 0);
EXPECT_EQ(ConvertSatNorm<uint8_t>(0.6f / 255), 1);
EXPECT_EQ(ConvertSatNorm<int16_t>(2.0f), 0x7fff);
EXPECT_EQ(ConvertSatNorm<int16_t>(-2.0f), -0x8000);
}
TEST(ConvertNorm, int2float) {
EXPECT_EQ((ConvertNorm<float, uint8_t>(255)), 1.0f);
EXPECT_NEAR((ConvertNorm<float, uint8_t>(127)), 1.0f * 127 / 255, 1e-7f);
EXPECT_EQ((ConvertNorm<float, int8_t>(127)), 1.0f);
EXPECT_NEAR((ConvertNorm<float, int8_t>(64)), 1.0f * 64 / 127, 1e-7f);
}
TEST(Clamp1, int64_2_float16) {
int64_t big_num = 0x0FFFFFFFFFFFFFFF;
EXPECT_EQ(static_cast<float>(Clamp<float16>(big_num)), Clamp<float16>(Clamp<float>(big_num)));
EXPECT_EQ(65504.0f, Clamp<float16>(big_num));
EXPECT_EQ(-65504.0f, Clamp<float16>(-big_num));
}
TEST(Clamp2, float16_2_int64) {
float16 fp16 = static_cast<float16>(65504.0f);
EXPECT_EQ(65504, Clamp<int64_t>(fp16));
EXPECT_EQ(-65504, Clamp<int64_t>(-fp16));
}
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
#define ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
#include "oneflow/core/common/data_type_converter.h"
namespace oneflow {
namespace {
// fp to int
static_assert(NeedsClamp<float, int8_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint8_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int16_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint16_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int32_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint32_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, int64_t>::value, "Float range exceeds all ints up to 64b");
static_assert(NeedsClamp<float, uint64_t>::value, "Float range exceeds all ints up to 64b");
// same size, different signedness
static_assert(NeedsClamp<int8_t, uint8_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint8_t, int8_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int16_t, uint16_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint16_t, int16_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int32_t, uint32_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint32_t, int32_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<int64_t, uint64_t>::value, "Signed <-> unsigned requires clamp");
static_assert(NeedsClamp<uint64_t, int64_t>::value, "Signed <-> unsigned requires clamp");
// larger, but unsigned
static_assert(NeedsClamp<int8_t, uint16_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int8_t, uint32_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int8_t, uint64_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int16_t, uint32_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int16_t, uint64_t>::value, "Need to clamp negatives to 0");
static_assert(NeedsClamp<int32_t, uint64_t>::value, "Need to clamp negatives to 0");
static_assert(!NeedsClamp<int8_t, int8_t>::value, "Clamping not required");
static_assert(!NeedsClamp<int8_t, int16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<uint8_t, int16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<uint8_t, uint16_t>::value, "Clamping not required");
static_assert(!NeedsClamp<float, float>::value, "Clamping not required");
static_assert(!NeedsClamp<float, double>::value, "Clamping not required");
} // namespace
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_CONVERTER_TEST_STATIC_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment