Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
...@@ -26,10 +26,10 @@ IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size) ...@@ -26,10 +26,10 @@ IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size)
mr_ = ibv::wrapper.ibv_reg_mr_wrap( mr_ = ibv::wrapper.ibv_reg_mr_wrap(
pd, mem_ptr, byte_size, pd, mem_ptr, byte_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ); IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(mr_); PCHECK(mr_);
} }
IBVerbsMemDesc::~IBVerbsMemDesc() { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr_), 0); } IBVerbsMemDesc::~IBVerbsMemDesc() { PCHECK(ibv::wrapper.ibv_dereg_mr(mr_) == 0); }
} // namespace oneflow } // namespace oneflow
......
...@@ -40,7 +40,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p ...@@ -40,7 +40,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p
port_num_ = port_num; port_num_ = port_num;
// qp_ // qp_
ibv_device_attr device_attr{}; ibv_device_attr device_attr{};
CHECK_EQ(ibv::wrapper.ibv_query_device(ctx, &device_attr), 0); PCHECK(ibv::wrapper.ibv_query_device(ctx, &device_attr) == 0);
const int64_t user_queue_depth = const int64_t user_queue_depth =
ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_QUEUE_DEPTH", kDefaultQueueDepth); ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_QUEUE_DEPTH", kDefaultQueueDepth);
const uint32_t queue_depth = std::min<uint32_t>(device_attr.max_qp_wr, user_queue_depth); const uint32_t queue_depth = std::min<uint32_t>(device_attr.max_qp_wr, user_queue_depth);
...@@ -57,7 +57,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p ...@@ -57,7 +57,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p
qp_init_attr.qp_type = IBV_QPT_RC; qp_init_attr.qp_type = IBV_QPT_RC;
qp_init_attr.sq_sig_all = 1; qp_init_attr.sq_sig_all = 1;
qp_ = ibv::wrapper.ibv_create_qp(pd, &qp_init_attr); qp_ = ibv::wrapper.ibv_create_qp(pd, &qp_init_attr);
CHECK(qp_); PCHECK(qp_);
// recv_msg_buf_ // recv_msg_buf_
recv_msg_buf_.assign(queue_depth, nullptr); recv_msg_buf_.assign(queue_depth, nullptr);
FOR_RANGE(size_t, i, 0, recv_msg_buf_.size()) { recv_msg_buf_.at(i) = new ActorMsgMR(pd_); } FOR_RANGE(size_t, i, 0, recv_msg_buf_.size()) { recv_msg_buf_.at(i) = new ActorMsgMR(pd_); }
...@@ -71,7 +71,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p ...@@ -71,7 +71,7 @@ IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& p
} }
IBVerbsQP::~IBVerbsQP() { IBVerbsQP::~IBVerbsQP() {
CHECK_EQ(ibv::wrapper.ibv_destroy_qp(qp_), 0); PCHECK(ibv::wrapper.ibv_destroy_qp(qp_) == 0);
while (send_msg_buf_.empty() == false) { while (send_msg_buf_.empty() == false) {
delete send_msg_buf_.front(); delete send_msg_buf_.front();
send_msg_buf_.pop(); send_msg_buf_.pop();
...@@ -90,9 +90,9 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) { ...@@ -90,9 +90,9 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {
qp_attr.port_num = port_num_; qp_attr.port_num = port_num_;
qp_attr.qp_access_flags = qp_attr.qp_access_flags =
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
CHECK_EQ(ibv::wrapper.ibv_modify_qp( PCHECK(ibv::wrapper.ibv_modify_qp(
qp_, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS), qp_, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)
0); == 0);
// IBV_QPS_RTR // IBV_QPS_RTR
memset(&qp_attr, 0, sizeof(ibv_qp_attr)); memset(&qp_attr, 0, sizeof(ibv_qp_attr));
...@@ -120,11 +120,11 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) { ...@@ -120,11 +120,11 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {
qp_attr.rq_psn = 0; qp_attr.rq_psn = 0;
qp_attr.max_dest_rd_atomic = 1; qp_attr.max_dest_rd_atomic = 1;
qp_attr.min_rnr_timer = 12; qp_attr.min_rnr_timer = 12;
CHECK_EQ(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr, PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN
| IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC
| IBV_QP_MIN_RNR_TIMER), | IBV_QP_MIN_RNR_TIMER)
0); == 0);
// IBV_QPS_RTS // IBV_QPS_RTS
memset(&qp_attr, 0, sizeof(ibv_qp_attr)); memset(&qp_attr, 0, sizeof(ibv_qp_attr));
...@@ -134,11 +134,10 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) { ...@@ -134,11 +134,10 @@ void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {
qp_attr.retry_cnt = 7; qp_attr.retry_cnt = 7;
qp_attr.rnr_retry = 7; qp_attr.rnr_retry = 7;
qp_attr.timeout = 14; qp_attr.timeout = 14;
CHECK_EQ(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr, PCHECK(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,
IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC
| IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_TIMEOUT), | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_TIMEOUT)
== 0);
0);
} }
void IBVerbsQP::PostAllRecvRequest() { void IBVerbsQP::PostAllRecvRequest() {
...@@ -197,7 +196,7 @@ void IBVerbsQP::EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge) { ...@@ -197,7 +196,7 @@ void IBVerbsQP::EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge) {
if (num_outstanding_send_wr_ < max_outstanding_send_wr_) { if (num_outstanding_send_wr_ < max_outstanding_send_wr_) {
num_outstanding_send_wr_++; num_outstanding_send_wr_++;
ibv_send_wr* bad_wr = nullptr; ibv_send_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0); PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0);
} else { } else {
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::make_pair(wr, sge); std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::make_pair(wr, sge);
pending_send_wr_queue_.push(ibv_send_wr_sge); pending_send_wr_queue_.push(ibv_send_wr_sge);
...@@ -239,7 +238,7 @@ void IBVerbsQP::PostPendingSendWR() { ...@@ -239,7 +238,7 @@ void IBVerbsQP::PostPendingSendWR() {
wr.sg_list = &ibv_send_wr_sge.second; wr.sg_list = &ibv_send_wr_sge.second;
pending_send_wr_queue_.pop(); pending_send_wr_queue_.pop();
ibv_send_wr* bad_wr = nullptr; ibv_send_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0); PCHECK(ibv_post_send(qp_, &wr, &bad_wr) == 0);
} else { } else {
if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; } if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }
} }
...@@ -258,7 +257,7 @@ void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) { ...@@ -258,7 +257,7 @@ void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {
wr.sg_list = &sge; wr.sg_list = &sge;
wr.num_sge = 1; wr.num_sge = 1;
ibv_recv_wr* bad_wr = nullptr; ibv_recv_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_recv(qp_, &wr, &bad_wr), 0); PCHECK(ibv_post_recv(qp_, &wr, &bad_wr) == 0);
} }
ActorMsgMR* IBVerbsQP::GetOneSendMsgMRFromBuf() { ActorMsgMR* IBVerbsQP::GetOneSendMsgMRFromBuf() {
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_H_
#define ONEFLOW_CORE_COMMON_BFLOAT16_H_
#include <stdint.h>
#include <limits>
#include <cmath>
#include <cstring>
namespace oneflow {
#if defined(__CUDACC__)
#define OF_DEVICE_FUNCTION __device__ __host__ __forceinline__
#else
#define OF_DEVICE_FUNCTION inline
#endif
struct alignas(2) bfloat16 {
uint16_t x;
bfloat16() = default;
bfloat16(const bfloat16& o) = default;
bfloat16& operator=(const bfloat16& o) = default;
bfloat16(bfloat16&& o) = default;
bfloat16& operator=(bfloat16&& o) = default;
~bfloat16() = default;
struct from_bits_t {};
static constexpr inline from_bits_t from_bits() { return from_bits_t(); }
constexpr inline bfloat16(unsigned short bits, from_bits_t) : x(bits){};
// reference: pytorch/c10/util/BFloat16.h
// https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16.h
bfloat16(float value) {
if (std::isnan(value)) {
x = 0x7FC0;
} else {
union {
uint32_t U32;
float F32;
};
F32 = value;
uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFFU;
x = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
}
inline operator float() const {
float res = 0;
uint32_t tmp = x;
tmp <<= 16;
std::memcpy(&res, &tmp, sizeof(tmp));
return res;
}
inline bool operator==(const bfloat16& other) const { return x == other.x; }
inline explicit operator bool() const { return (x & 0x7fff) != 0; }
inline explicit operator int8_t() const { return static_cast<int8_t>(static_cast<float>(*this)); }
inline explicit operator uint8_t() const {
return static_cast<uint8_t>(static_cast<float>(*this));
}
inline explicit operator int16_t() const {
return static_cast<int16_t>(static_cast<float>(*this));
}
inline explicit operator uint16_t() const {
return static_cast<uint16_t>(static_cast<float>(*this));
}
inline explicit operator int32_t() const {
return static_cast<int32_t>(static_cast<float>(*this));
}
inline explicit operator uint32_t() const {
return static_cast<uint32_t>(static_cast<float>(*this));
}
inline explicit operator int64_t() const {
return static_cast<int64_t>(static_cast<float>(*this));
}
inline explicit operator uint64_t() const {
return static_cast<uint64_t>(static_cast<float>(*this));
}
inline explicit operator double() const { return static_cast<double>(static_cast<float>(*this)); }
};
// Arithmetic
inline bfloat16 operator+(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline bfloat16 operator-(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline bfloat16 operator*(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline bfloat16 operator/(const bfloat16& a, const bfloat16& b) {
return static_cast<float>(a) / static_cast<float>(b);
}
inline bfloat16 operator-(const bfloat16& a) {
bfloat16 output;
output.x = a.x ^ 0x8000U;
return output;
}
inline bfloat16& operator+=(bfloat16& a, const bfloat16& b) {
a = a + b;
return a;
}
inline bfloat16& operator-=(bfloat16& a, const bfloat16& b) {
a = a - b;
return a;
}
inline bfloat16& operator*=(bfloat16& a, const bfloat16& b) {
a = a * b;
return a;
}
inline bfloat16& operator/=(bfloat16& a, const bfloat16& b) {
a = a / b;
return a;
}
inline bfloat16& operator|(bfloat16& a, const bfloat16& b) {
a.x = a.x | b.x;
return a;
}
inline bfloat16& operator^(bfloat16& a, const bfloat16& b) {
a.x = a.x ^ b.x;
return a;
}
inline bfloat16& operator&(bfloat16& a, const bfloat16& b) {
a.x = a.x & b.x;
return a;
}
// Arithmetic with floats
inline float operator+(bfloat16 a, float b) { return static_cast<float>(a) + b; }
inline float operator-(bfloat16 a, float b) { return static_cast<float>(a) - b; }
inline float operator*(bfloat16 a, float b) { return static_cast<float>(a) * b; }
inline float operator/(bfloat16 a, float b) { return static_cast<float>(a) / b; }
inline float operator+(float a, bfloat16 b) { return a + static_cast<float>(b); }
inline float operator-(float a, bfloat16 b) { return a - static_cast<float>(b); }
inline float operator*(float a, bfloat16 b) { return a * static_cast<float>(b); }
inline float operator/(float a, bfloat16 b) { return a / static_cast<float>(b); }
inline float& operator+=(float& a, const bfloat16& b) { return a += static_cast<float>(b); }
inline float& operator-=(float& a, const bfloat16& b) { return a -= static_cast<float>(b); }
inline float& operator*=(float& a, const bfloat16& b) { return a *= static_cast<float>(b); }
inline float& operator/=(float& a, const bfloat16& b) { return a /= static_cast<float>(b); }
// Arithmetic with doubles
inline double operator+(bfloat16 a, double b) { return static_cast<double>(a) + b; }
inline double operator-(bfloat16 a, double b) { return static_cast<double>(a) - b; }
inline double operator*(bfloat16 a, double b) { return static_cast<double>(a) * b; }
inline double operator/(bfloat16 a, double b) { return static_cast<double>(a) / b; }
inline double operator+(double a, bfloat16 b) { return a + static_cast<double>(b); }
inline double operator-(double a, bfloat16 b) { return a - static_cast<double>(b); }
inline double operator*(double a, bfloat16 b) { return a * static_cast<double>(b); }
inline double operator/(double a, bfloat16 b) { return a / static_cast<double>(b); }
// Arithmetic with int32_t
inline bfloat16 operator+(bfloat16 a, int32_t b) { return a + static_cast<bfloat16>(b); }
inline bfloat16 operator-(bfloat16 a, int32_t b) { return a - static_cast<bfloat16>(b); }
inline bfloat16 operator*(bfloat16 a, int32_t b) { return a * static_cast<bfloat16>(b); }
inline bfloat16 operator/(bfloat16 a, int32_t b) { return a / static_cast<bfloat16>(b); }
inline bfloat16 operator+(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }
inline bfloat16 operator-(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }
inline bfloat16 operator*(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }
inline bfloat16 operator/(int32_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }
// Arithmetic with int64_t
inline bfloat16 operator+(bfloat16 a, int64_t b) { return a + static_cast<bfloat16>(b); }
inline bfloat16 operator-(bfloat16 a, int64_t b) { return a - static_cast<bfloat16>(b); }
inline bfloat16 operator*(bfloat16 a, int64_t b) { return a * static_cast<bfloat16>(b); }
inline bfloat16 operator/(bfloat16 a, int64_t b) { return a / static_cast<bfloat16>(b); }
inline bfloat16 operator+(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) + b; }
inline bfloat16 operator-(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) - b; }
inline bfloat16 operator*(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) * b; }
inline bfloat16 operator/(int64_t a, bfloat16 b) { return static_cast<bfloat16>(a) / b; }
// Comparison operators
inline bool operator>(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) > static_cast<float>(rhs);
}
inline bool operator>=(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) >= static_cast<float>(rhs);
}
inline bool operator<(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) < static_cast<float>(rhs);
}
inline bool operator<=(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) <= static_cast<float>(rhs);
}
inline bool operator==(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) == static_cast<float>(rhs);
}
inline bool operator!=(bfloat16& lhs, bfloat16& rhs) {
return static_cast<float>(lhs) != static_cast<float>(rhs);
}
} // namespace oneflow
namespace std {
inline bool isnan(const oneflow::bfloat16& value) { return (value.x & 0x7FFFU) > 0x07F80U; }
inline bool isinf(const oneflow::bfloat16& value) { return value.x == 0x07F80U; }
inline bool isfinite(const oneflow::bfloat16& value) { return !isinf(value) && !isnan(value); }
template<>
class numeric_limits<oneflow::bfloat16> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_specialized = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss = numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
static constexpr int radix = 2;
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = numeric_limits<float>::tinyness_before;
static constexpr oneflow::bfloat16 min() {
return oneflow::bfloat16(0x0080U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 lowest() {
return oneflow::bfloat16(0xFF7FU, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 max() {
return oneflow::bfloat16(0x7F7FU, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 epsilon() {
return oneflow::bfloat16(0x3C00U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 round_error() {
return oneflow::bfloat16(0x3F00U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 infinity() {
return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 quiet_NaN() {
return oneflow::bfloat16(0x7FC0U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 signaling_NaN() {
return oneflow::bfloat16(0x7F80U, oneflow::bfloat16::from_bits());
}
static constexpr oneflow::bfloat16 denorm_min() {
return oneflow::bfloat16(0x0001U, oneflow::bfloat16::from_bits());
}
};
} // namespace std
#endif // ONEFLOW_CORE_COMMON_BFLOAT16_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
#define ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
#include "oneflow/core/common/bfloat16.h"
namespace std {
// reference: pytorch/c10/util/BFloat16-math.h
// https://github.com/pytorch/pytorch/blob/release/1.12/c10/util/BFloat16-math.h
inline oneflow::bfloat16 acos(oneflow::bfloat16 a) { return std::acos(static_cast<float>(a)); }
inline oneflow::bfloat16 asin(oneflow::bfloat16 a) { return std::asin(static_cast<float>(a)); }
inline oneflow::bfloat16 atan(oneflow::bfloat16 a) { return std::atan(static_cast<float>(a)); }
inline oneflow::bfloat16 erf(oneflow::bfloat16 a) { return std::erf(static_cast<float>(a)); }
inline oneflow::bfloat16 erfc(oneflow::bfloat16 a) { return std::erfc(static_cast<float>(a)); }
inline oneflow::bfloat16 exp(oneflow::bfloat16 a) { return std::exp(static_cast<float>(a)); }
inline oneflow::bfloat16 expm1(oneflow::bfloat16 a) { return std::expm1(static_cast<float>(a)); }
inline oneflow::bfloat16 log(oneflow::bfloat16 a) { return std::log(static_cast<float>(a)); }
inline oneflow::bfloat16 log10(oneflow::bfloat16 a) { return std::log10(static_cast<float>(a)); }
inline oneflow::bfloat16 log1p(oneflow::bfloat16 a) { return std::log1p(static_cast<float>(a)); }
inline oneflow::bfloat16 log2(oneflow::bfloat16 a) { return std::log2(static_cast<float>(a)); }
inline oneflow::bfloat16 ceil(oneflow::bfloat16 a) { return std::ceil(static_cast<float>(a)); }
inline oneflow::bfloat16 cos(oneflow::bfloat16 a) { return std::cos(static_cast<float>(a)); }
inline oneflow::bfloat16 floor(oneflow::bfloat16 a) { return std::floor(static_cast<float>(a)); }
inline oneflow::bfloat16 nearbyint(oneflow::bfloat16 a) {
return std::nearbyint(static_cast<float>(a));
}
inline oneflow::bfloat16 sin(oneflow::bfloat16 a) { return std::sin(static_cast<float>(a)); }
inline oneflow::bfloat16 tan(oneflow::bfloat16 a) { return std::tan(static_cast<float>(a)); }
inline oneflow::bfloat16 sinh(oneflow::bfloat16 a) { return std::sinh(static_cast<float>(a)); }
inline oneflow::bfloat16 cosh(oneflow::bfloat16 a) { return std::cosh(static_cast<float>(a)); }
inline oneflow::bfloat16 tanh(oneflow::bfloat16 a) { return std::tanh(static_cast<float>(a)); }
inline oneflow::bfloat16 trunc(oneflow::bfloat16 a) { return std::trunc(static_cast<float>(a)); }
inline oneflow::bfloat16 lgamma(oneflow::bfloat16 a) { return std::lgamma(static_cast<float>(a)); }
inline oneflow::bfloat16 sqrt(oneflow::bfloat16 a) { return std::sqrt(static_cast<float>(a)); }
inline oneflow::bfloat16 rsqrt(oneflow::bfloat16 a) {
return 1.0 / std::sqrt(static_cast<float>(a));
}
inline oneflow::bfloat16 abs(oneflow::bfloat16 a) { return std::abs(static_cast<float>(a)); }
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, double b) {
return std::pow(static_cast<float>(a), b);
}
inline oneflow::bfloat16 pow(oneflow::bfloat16 a, oneflow::bfloat16 b) {
return std::pow(static_cast<float>(a), static_cast<float>(b));
}
inline oneflow::bfloat16 fmod(oneflow::bfloat16 a, oneflow::bfloat16 b) {
return std::fmod(static_cast<float>(a), static_cast<float>(b));
}
} // namespace std
#endif // ONEFLOW_CORE_COMMON_BFLOAT16_MATH_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/bfloat16.h"
#include "oneflow/core/common/bfloat16_math.h"
namespace oneflow {
namespace test {
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
// reference: pytorch/c10/test/util/bfloat16_test.cpp
// https://github.com/pytorch/pytorch/blob/release/1.12/c10/test/util/bfloat16_test.cpp
uint32_t bytes = 0;
bytes |= sign;
bytes <<= 8;
bytes |= exponent;
bytes <<= 23;
bytes |= fraction;
float res = NAN;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}
TEST(BFLOAT16MATH, Add) {
// 6.25
float input = float_from_bytes(0, 0, 0x40C80000U);
// 7.25
float expected = float_from_bytes(0, 0, 0x40E80000U);
bfloat16 b(input);
b = b + 1;
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Sub) {
// 7.25
float input = float_from_bytes(0, 0, 0x40E80000U);
// 6.25
float expected = float_from_bytes(0, 0, 0x40C80000U);
bfloat16 b(input);
b = b - 1;
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Mul) {
// 3.125
float input = float_from_bytes(0, 0, 0x40480000U);
// 6.25
float expected = float_from_bytes(0, 0, 0x40C80000U);
bfloat16 b(input);
b = b * 2;
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Div) {
// 6.25
float input = float_from_bytes(0, 0, 0x40C80000U);
// 3.125
float expected = float_from_bytes(0, 0, 0x40480000U);
bfloat16 b(input);
b = b / 2;
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Log2) {
// 16
float input = float_from_bytes(0, 0, 0x41800000U);
// 4
float expected = float_from_bytes(0, 0, 0x40800000U);
bfloat16 b(input);
b = std::log2(b);
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Log10) {
// 100
float input = float_from_bytes(0, 0, 0x42C80000U);
// 2
float expected = float_from_bytes(0, 0, 0x40000000U);
bfloat16 b(input);
b = std::log10(b);
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
TEST(BFLOAT16MATH, Sqrt) {
// 25
float input = float_from_bytes(0, 0, 0x41C80000U);
// 5
float expected = float_from_bytes(0, 0, 0x40A00000U);
bfloat16 b(input);
b = std::sqrt(b);
float res = static_cast<float>(b);
EXPECT_EQ(res, expected);
}
} // namespace test
} // namespace oneflow
...@@ -18,27 +18,23 @@ limitations under the License. ...@@ -18,27 +18,23 @@ limitations under the License.
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#ifdef WITH_CUDA
#include <cuda_fp16.h>
#endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_fp16.h>
#endif // WITH_ROCM
#include "oneflow/core/common/cblas.h" #include "oneflow/core/common/cblas.h"
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
namespace oneflow { namespace oneflow {
#define BLAS_NAME_SEQ \ #define BLAS_NAME_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dot) \ OF_PP_MAKE_TUPLE_SEQ(dot) \
OF_PP_MAKE_TUPLE_SEQ(swap) \ OF_PP_MAKE_TUPLE_SEQ(swap) \
OF_PP_MAKE_TUPLE_SEQ(copy) \ OF_PP_MAKE_TUPLE_SEQ(copy) \
OF_PP_MAKE_TUPLE_SEQ(axpy) \ OF_PP_MAKE_TUPLE_SEQ(axpy) \
OF_PP_MAKE_TUPLE_SEQ(scal) \ OF_PP_MAKE_TUPLE_SEQ(scal) \
OF_PP_MAKE_TUPLE_SEQ(gemv) \ OF_PP_MAKE_TUPLE_SEQ(gemv) \
OF_PP_MAKE_TUPLE_SEQ(gemm) \ OF_PP_MAKE_TUPLE_SEQ(gemm) \
OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \ OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \
OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched) OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched) \
OF_PP_MAKE_TUPLE_SEQ(getrfBatched) \
OF_PP_MAKE_TUPLE_SEQ(getriBatched)
#define CBLAS_TEMPLATE(name) \ #define CBLAS_TEMPLATE(name) \
template<typename T, typename... Args> \ template<typename T, typename... Args> \
......
...@@ -17,7 +17,7 @@ limitations under the License. ...@@ -17,7 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_ #define ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_
#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/blocking_counter.h" #include "oneflow/core/common/notifier.h"
#include "oneflow/core/common/spin_counter.h" #include "oneflow/core/common/spin_counter.h"
namespace oneflow { namespace oneflow {
...@@ -26,20 +26,22 @@ class BlockingThenBusy final { ...@@ -26,20 +26,22 @@ class BlockingThenBusy final {
public: public:
BlockingThenBusy(const BlockingThenBusy&) = delete; BlockingThenBusy(const BlockingThenBusy&) = delete;
BlockingThenBusy(BlockingThenBusy&&) = delete; BlockingThenBusy(BlockingThenBusy&&) = delete;
BlockingThenBusy() = delete; constexpr static int kCnt = 1;
explicit BlockingThenBusy(int cnt) : blocking_counter_(cnt), spin_counter_(cnt) {} BlockingThenBusy() : notifier_(), spin_counter_(kCnt) {}
BlockingCounter* mut_blocking_counter() { return &blocking_counter_; } Notifier* mut_notifier() { return &notifier_; }
SpinCounter* mut_spin_counter() { return &spin_counter_; } SpinCounter* mut_spin_counter() { return &spin_counter_; }
void Reset() { mut_spin_counter()->Reset(kCnt); }
Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopAfterTimeout) { Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopAfterTimeout) {
JUST(blocking_counter_.WaitUntilCntEqualZero(StopAfterTimeout)); JUST(notifier_.TimedWaitAndClearNotifiedCnt(StopAfterTimeout));
JUST(spin_counter_.WaitUntilCntEqualZero()); JUST(spin_counter_.WaitUntilCntEqualZero());
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
private: private:
BlockingCounter blocking_counter_; Notifier notifier_;
SpinCounter spin_counter_; SpinCounter spin_counter_;
}; };
......
...@@ -70,23 +70,13 @@ inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& ...@@ -70,23 +70,13 @@ inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string&
return prefix + job_name; return prefix + job_name;
} }
inline std::string GetForeignInputBufferName(const std::string& job_name) {
static const std::string prefix = "ForeignInput-";
return prefix + job_name;
}
inline std::string GetForeignOutputBufferName(const std::string& job_name) {
static const std::string prefix = "ForeignOutput-";
return prefix + job_name;
}
inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) { inline std::string GetInputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "ForeignInput-"; static const std::string prefix = "Input-";
return prefix + job_name + "-" + op_name; return prefix + job_name + "-" + op_name;
} }
inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) { inline std::string GetOutputBufferName(const std::string& job_name, const std::string& op_name) {
static const std::string prefix = "ForeignOutput-"; static const std::string prefix = "Output-";
return prefix + job_name + "-" + op_name; return prefix + job_name + "-" + op_name;
} }
......
...@@ -24,6 +24,7 @@ static const int64_t kInvalidSessionId = -1; ...@@ -24,6 +24,7 @@ static const int64_t kInvalidSessionId = -1;
static const std::string kNoPassTag = ""; static const std::string kNoPassTag = "";
static const std::string kMainOp = "main_op"; static const std::string kMainOp = "main_op";
static const int64_t kMaxSplitAxis = 6; static const int64_t kMaxSplitAxis = 6;
constexpr size_t kMaxNumDims = 8;
static const std::string kAsymmetricCodeErrorMsg = static const std::string kAsymmetricCodeErrorMsg =
"Maybe executing different code in different ranks, please check if the code is branched and " "Maybe executing different code in different ranks, please check if the code is branched and "
"operates on the global tensor."; "operates on the global tensor.";
......
...@@ -82,18 +82,6 @@ std::string Join(const T& con, const std::string& delimiter) { ...@@ -82,18 +82,6 @@ std::string Join(const T& con, const std::string& delimiter) {
return os.str(); return os.str();
} }
template<typename T>
using SmallSet = std::vector<T>;
template<typename T>
std::pair<typename SmallSet<T>::iterator, bool> SmallSetInsert(SmallSet<T>* vec, const T& elem) {
for (auto iter = vec->begin(); iter != vec->end(); ++iter) {
if (*iter == elem) { return std::make_pair(iter, false); }
}
vec->push_back(elem);
return std::make_pair(--vec->end(), true);
}
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_ #endif // ONEFLOW_CORE_COMMON_CONTAINER_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_TIME_UTIL_H_
#define ONEFLOW_CORE_COMMON_TIME_UTIL_H_
#include <chrono>
#include <sstream>
#include <string>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/mem_util.h"
#include "oneflow/core/job/utils/progress_bar.h"
namespace oneflow {
template<typename DurationT>
struct Duration {
static const std::string& Repr() {
static const std::string repr = "";
return repr;
}
};
#define DEFINE_DURATION_TRAIT(time_type) \
template<> \
struct Duration<typename std::chrono::time_type> { \
static const std::string& Repr() { \
static const std::string repr = #time_type; \
return repr; \
} \
};
DEFINE_DURATION_TRAIT(nanoseconds)
DEFINE_DURATION_TRAIT(microseconds)
DEFINE_DURATION_TRAIT(milliseconds)
DEFINE_DURATION_TRAIT(seconds)
DEFINE_DURATION_TRAIT(minutes)
DEFINE_DURATION_TRAIT(hours)
#undef DEFINE_DURATION_TRAIT
template<class Resolution = std::chrono::seconds>
class CostCounter final {
public:
OF_DISALLOW_COPY_AND_MOVE(CostCounter);
explicit CostCounter(bool with_log = true, bool with_mem = false)
: with_log_(with_log), with_mem_(with_mem) {}
~CostCounter() = default;
void Count(const std::string& log_prefix = "", int v_log_level = 0, bool log_progress = false);
private:
using Clock = std::conditional_t<std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock, std::chrono::steady_clock>;
Clock::time_point start_{Clock::now()};
bool with_log_{false};
bool with_mem_{false};
};
template<class Resolution>
void CostCounter<Resolution>::Count(const std::string& log_prefix, int v_log_level,
bool log_progress) {
if (log_progress) { CHECK_JUST(LogProgress(log_prefix)); }
const auto end = Clock::now();
if (FLAGS_minloglevel <= 0 && VLOG_IS_ON(v_log_level) && with_log_ && v_log_level >= 0) {
// only do time/mem count and log when glog level is INFO and VLOG level is matched.
auto dur = std::chrono::duration_cast<Resolution>(end - start_).count();
nlohmann::json json_log;
json_log["loc"] = log_prefix;
json_log["time_cost"] = std::to_string(dur) + " " + Duration<Resolution>::Repr();
if (with_mem_) {
#ifdef __linux__
double vm = 0, rss = 0;
ProcessMemUsage(&vm, &rss);
json_log["mem_rss"] = std::to_string(rss) + " MB";
#endif // __linux__
}
if (v_log_level == 0) {
LOG(INFO) << "[count log]" << json_log.dump();
} else {
VLOG(v_log_level) << "[count log]" << json_log.dump();
}
}
start_ = end;
return;
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_TIME_UTIL_H_
...@@ -46,6 +46,15 @@ bool IsFloatingDataType(DataType data_type) { ...@@ -46,6 +46,15 @@ bool IsFloatingDataType(DataType data_type) {
} }
#undef FLOATING_CASE #undef FLOATING_CASE
} }
bool IsHalfDataType(DataType data_type) {
switch (data_type) {
#define HALF_CASE(type_cpp, type_proto) \
case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(HALF_CASE, FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ)
default: return false;
}
#undef HALF_CASE
}
bool IsPODDataType(DataType data_type) { bool IsPODDataType(DataType data_type) {
switch (data_type) { switch (data_type) {
#define POD_CASE(type_cpp, type_proto) \ #define POD_CASE(type_cpp, type_proto) \
...@@ -77,7 +86,8 @@ bool IsSupportRequireGradDataType(DataType data_type) { ...@@ -77,7 +86,8 @@ bool IsSupportRequireGradDataType(DataType data_type) {
switch (data_type) { switch (data_type) {
#define REQUIRE_GRAD_CASE(type_cpp, type_proto) \ #define REQUIRE_GRAD_CASE(type_cpp, type_proto) \
case type_proto: return true; case type_proto: return true;
OF_PP_FOR_EACH_TUPLE(REQUIRE_GRAD_CASE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ) OF_PP_FOR_EACH_TUPLE(REQUIRE_GRAD_CASE,
FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ)
default: return false; default: return false;
} }
#undef REQUIRE_GRAD_CASE #undef REQUIRE_GRAD_CASE
...@@ -130,6 +140,58 @@ size_t GetSizeOfDataType(DataType data_type) { ...@@ -130,6 +140,58 @@ size_t GetSizeOfDataType(DataType data_type) {
} }
} }
int64_t GetIntMaxVal(DataType datatype) {
#define SWITCH_INT_TYPE(cpp_type, of_datatype) \
case of_datatype: return static_cast<int64_t>(GetMaxVal<DataTypeToType<of_datatype>>());
switch (datatype) {
OF_PP_FOR_EACH_TUPLE(SWITCH_INT_TYPE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)
default:
LOG(FATAL) << "invalid data_type: " << DataType_Name(datatype)
<< " for GetIntMaxVal(DataType)";
}
#undef SWITCH_INT_TYPE
}
int64_t GetIntMinVal(DataType datatype) {
#define SWITCH_INT_TYPE(cpp_type, of_datatype) \
case of_datatype: return static_cast<int64_t>(GetMinVal<DataTypeToType<of_datatype>>());
switch (datatype) {
OF_PP_FOR_EACH_TUPLE(SWITCH_INT_TYPE, INT_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ)
default:
LOG(FATAL) << "invalid data_type: " << DataType_Name(datatype)
<< " for GetIntMinVal(DataType)";
}
#undef SWITCH_INT_TYPE
}
double GetFloatMaxVal(DataType datatype) {
#define SWITCH_FLOAT_TYPE(cpp_type, of_datatype) \
case of_datatype: return static_cast<double>(GetMaxVal<DataTypeToType<of_datatype>>());
switch (datatype) {
OF_PP_FOR_EACH_TUPLE(SWITCH_FLOAT_TYPE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)
default:
LOG(FATAL) << "invalid data_type: " << DataType_Name(datatype)
<< " for GetFloatMaxVal(DataType)";
}
#undef SWITCH_FLOAT_TYPE
}
double GetFloatMinVal(DataType datatype) {
#define SWITCH_FLOAT_TYPE(cpp_type, of_datatype) \
case of_datatype: return static_cast<double>(GetMinVal<DataTypeToType<of_datatype>>());
switch (datatype) {
OF_PP_FOR_EACH_TUPLE(SWITCH_FLOAT_TYPE, FLOATING_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ)
default:
LOG(FATAL) << "invalid data_type: " << DataType_Name(datatype)
<< " for GetFloatMinVal(DataType)";
}
#undef SWITCH_INT_TYPE
}
namespace { namespace {
void CheckDataType() { void CheckDataType() {
......
...@@ -20,11 +20,17 @@ limitations under the License. ...@@ -20,11 +20,17 @@ limitations under the License.
#include <type_traits> #include <type_traits>
#if defined(WITH_CUDA) #if defined(WITH_CUDA)
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#endif #endif
#if defined(WITH_ROCM) #if defined(WITH_ROCM)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#endif #endif
#include "oneflow/core/common/bfloat16.h"
#include "oneflow/core/common/bfloat16_math.h"
#include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/data_type_seq.h" #include "oneflow/core/common/data_type_seq.h"
#include "oneflow/core/record/record.pb.h" #include "oneflow/core/record/record.pb.h"
...@@ -34,8 +40,18 @@ limitations under the License. ...@@ -34,8 +40,18 @@ limitations under the License.
namespace oneflow { namespace oneflow {
template<>
struct IsScalarType<bfloat16> final {
static const bool value = true;
};
typedef half_float::half float16; typedef half_float::half float16;
template<>
struct IsScalarType<float16> final {
static const bool value = true;
};
template<typename T> template<typename T>
struct IsFloat16; struct IsFloat16;
...@@ -103,19 +119,23 @@ struct GetDataType<void> : std::integral_constant<DataType, DataType::kChar> {}; ...@@ -103,19 +119,23 @@ struct GetDataType<void> : std::integral_constant<DataType, DataType::kChar> {};
template<> \ template<> \
struct GetDataType<type_cpp> : std::integral_constant<DataType, type_proto> {}; \ struct GetDataType<type_cpp> : std::integral_constant<DataType, type_proto> {}; \
inline type_cpp GetTypeByDataType(std::integral_constant<DataType, type_proto>) { return {}; } inline type_cpp GetTypeByDataType(std::integral_constant<DataType, type_proto>) { return {}; }
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE, ALL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ); OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_DATA_TYPE,
ALL_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ);
#undef SPECIALIZE_GET_DATA_TYPE #undef SPECIALIZE_GET_DATA_TYPE
template<typename T> template<typename T>
struct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type> struct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type>
: std::integral_constant<DataType, DataType::kFloat16> {}; : std::integral_constant<DataType, DataType::kFloat16> {};
#if CUDA_VERSION >= 11000
template<>
struct GetDataType<nv_bfloat16> : std::integral_constant<DataType, DataType::kBFloat16> {};
#endif
template<DataType type> template<DataType type>
using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{})); using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{}));
#if defined(__CUDACC__) #if defined(__CUDACC__) || defined(__HIPCC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#elif defined(__HIPCC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__ #define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#else #else
#define OF_DEVICE_FUNC inline #define OF_DEVICE_FUNC inline
...@@ -240,7 +260,14 @@ struct DevDType<DeviceType::kCUDA, float16> { ...@@ -240,7 +260,14 @@ struct DevDType<DeviceType::kCUDA, float16> {
static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)"); static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
typedef half type; typedef half type;
}; };
#endif #if CUDA_VERSION >= 11000
template<>
struct DevDType<DeviceType::kCUDA, bfloat16> {
static_assert(sizeof(bfloat16) == sizeof(nv_bfloat16), "sizeof(bfloat16) != sizeof(nv_bfloat16)");
typedef nv_bfloat16 type;
};
#endif // CUDA_VERSION >= 11000
#endif // defined(WITH_CUDA)
#if defined(WITH_ROCM) #if defined(WITH_ROCM)
template<> template<>
...@@ -248,13 +275,21 @@ struct DevDType<DeviceType::kCUDA, float16> { ...@@ -248,13 +275,21 @@ struct DevDType<DeviceType::kCUDA, float16> {
static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)"); static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
typedef half type; typedef half type;
}; };
#endif // #if CUDA_VERSION >= 11000
// template<>
// struct DevDType<DeviceType::kCUDA, bfloat16> {
// static_assert(sizeof(bfloat16) == sizeof(nv_bfloat16), "sizeof(bfloat16) != sizeof(nv_bfloat16)");
// typedef nv_bfloat16 type;
// };
// #endif // CUDA_VERSION >= 11000
#endif // defined(WITH_ROCM)
// Func // Func
bool IsBoolDataType(DataType data_type); bool IsBoolDataType(DataType data_type);
bool IsIntegralDataType(DataType data_type); bool IsIntegralDataType(DataType data_type);
bool IsFloatingDataType(DataType data_type); bool IsFloatingDataType(DataType data_type);
bool IsHalfDataType(DataType data_type);
bool IsSupportRequireGradDataType(DataType data_type); bool IsSupportRequireGradDataType(DataType data_type);
bool IsPODDataType(DataType data_type); bool IsPODDataType(DataType data_type);
bool IsPODAndHalfDataType(DataType data_type); bool IsPODAndHalfDataType(DataType data_type);
...@@ -269,11 +304,16 @@ inline bool operator==(const OptInt64& lhs, const OptInt64& rhs) { ...@@ -269,11 +304,16 @@ inline bool operator==(const OptInt64& lhs, const OptInt64& rhs) {
template<typename T> template<typename T>
void CheckDataType(DataType data_type) { void CheckDataType(DataType data_type) {
LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false && std::is_same<T, long>::value == false LOG_IF(FATAL, (std::is_same<T, void>::value == false && std::is_same<T, char>::value == false
&& data_type != DataType::kChar && data_type != GetDataType<T>::value)) && data_type != DataType::kChar && data_type != GetDataType<T>::value))
<< data_type << " " << GetDataType<T>::value; << data_type << " " << GetDataType<T>::value;
} }
int64_t GetIntMaxVal(DataType datatype);
int64_t GetIntMinVal(DataType datatype);
double GetFloatMaxVal(DataType datatype);
double GetFloatMinVal(DataType datatype);
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_DATA_TYPE_H_ #endif // ONEFLOW_CORE_COMMON_DATA_TYPE_H_
...@@ -352,11 +352,10 @@ struct ConvertIntInt<Out, In, false, true> { ...@@ -352,11 +352,10 @@ struct ConvertIntInt<Out, In, false, true> {
OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) { OF_DEVICE_FUNC static const Out ConvertSatNorm(In value) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal<In>())) * GetMaxVal<Out>()); return cuda_round_helper(__saturatef(value * (1.0f / GetMaxVal<In>())) * GetMaxVal<Out>());
}
#else #else
return value < 0 ? 0 : ConvertNorm(value); return value < 0 ? 0 : ConvertNorm(value);
}
#endif #endif
}
}; };
// Converts between integral types // Converts between integral types
......
...@@ -43,7 +43,7 @@ limitations under the License. ...@@ -43,7 +43,7 @@ limitations under the License.
#define POD_DATA_TYPE_SEQ \ #define POD_DATA_TYPE_SEQ \
ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ ARITHMETIC_DATA_TYPE_SEQ CHAR_DATA_TYPE_SEQ UNSIGNED_INT_DATA_TYPE_SEQ BOOL_DATA_TYPE_SEQ
#define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ #define POD_AND_HALF_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ FLOAT16_DATA_TYPE_SEQ BFLOAT16_DATA_TYPE_SEQ
#define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord) #define PB_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(OFRecord, DataType::kOFRecord)
#define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ #define ALL_DATA_TYPE_SEQ POD_DATA_TYPE_SEQ PB_DATA_TYPE_SEQ
...@@ -53,12 +53,20 @@ limitations under the License. ...@@ -53,12 +53,20 @@ limitations under the License.
#define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) #define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)
#define BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16)
#if defined(WITH_CUDA) #if defined(WITH_CUDA)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) #define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif #if CUDA_VERSION >= 11000
#define NV_BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
#endif // CUDA_VERSION >= 11000
#endif // defined(WITH_CUDA)
#if defined(WITH_ROCM) #if defined(WITH_ROCM)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) #define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
// #if CUDA_VERSION >= 11000
// #define NV_BFLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16)
// #endif // CUDA_VERSION >= 11000
#endif #endif
#define IMAGE_DATA_TYPE_SEQ \ #define IMAGE_DATA_TYPE_SEQ \
......
...@@ -6,5 +6,4 @@ enum DeviceType { ...@@ -6,5 +6,4 @@ enum DeviceType {
kCPU = 1; kCPU = 1;
kCUDA = 2; kCUDA = 2;
kMockDevice = 3; // pseudo device for test. kMockDevice = 3; // pseudo device for test.
kROCm = 4;
} }
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
DEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS, 20);
DEFINE_ENV_INTEGER(ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES, 3);
DEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_SLEEP_SECONDS, 5);
DEFINE_ENV_INTEGER(ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES, 6);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_BOOTSTRAP_H_
...@@ -25,6 +25,13 @@ DEFINE_ENV_BOOL(ONEFLOW_DEBUG, false); ...@@ -25,6 +25,13 @@ DEFINE_ENV_BOOL(ONEFLOW_DEBUG, false);
inline bool IsInDebugMode() { return EnvBool<ONEFLOW_DEBUG_MODE>() || EnvBool<ONEFLOW_DEBUG>(); } inline bool IsInDebugMode() { return EnvBool<ONEFLOW_DEBUG_MODE>() || EnvBool<ONEFLOW_DEBUG>(); }
DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, false);
inline bool EnableLogicalChain() { return EnvBool<ENABLE_LOGICAL_CHAIN>(); }
inline bool IsPythonStackGetterEnabled() {
return ParseBooleanFromEnv("ONEFLOW_PYTHON_STACK_GETTER", IsInDebugMode());
}
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_ #endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
// NOTE: use env variable 'ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE' indicate whether the
// use infer cache in naive local op interpret.
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_EAGER_ENABLE_LOCAL_INFER_CACHE, true);
// NOTE: use env variable 'ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE' indicate the size of
// infer cache in op interpret.
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_EAGER_TENSOR_INFER_CACHE_SIZE, 128 * 1024);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_EAGER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_DEVICE_STREAM_MAX_SIZE, 16);
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_STREAM_ENABLE_H2D_STREAM, false);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_STREAM_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment