Commit 6046d8fb authored by yuguo960516yuguo

dtk23.04

parent a715222c
...
@@ -334,16 +334,16 @@ elseif(WIN32)
   set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
 endif()
-if (BUILD_ROCM)
-  # AMD compiler fails to compile these three files with '-O1/2/3'.
-  # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
-  # so '-O0' will override '-O1/2/3'.
-  set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
-    ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
-    ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
-    #${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
-    PROPERTIES COMPILE_OPTIONS "-O0")
-endif()
+# if (BUILD_ROCM)
+#   # AMD compiler fails to compile these three files with '-O1/2/3'.
+#   # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
+#   # so '-O0' will override '-O1/2/3'.
+#   set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
+#     ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
+#     ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
+#     # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
+#     PROPERTIES COMPILE_OPTIONS "-O0")
+# endif()
 if(BUILD_CUDA)
   string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
...
...
@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
 template<typename T>
 OF_DEVICE_FUNC T DeviceMin(T a, T b) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   return a < b ? a : b;
 #else
   return std::min(a, b);
...
@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
 template<typename T>
 OF_DEVICE_FUNC T DeviceMax(T a, T b) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   return a > b ? a : b;
 #else
   return std::max(a, b);
...
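A note on the `__HIP_DEVICE_COMPILE__` additions above: nvcc defines `__CUDA_ARCH__` only while compiling device code, but hipcc's device pass defines `__HIP_DEVICE_COMPILE__` instead, so checking `__CUDA_ARCH__` alone would route HIP device code to the host-only `std::min`/`std::max` branch. A minimal standalone sketch of the pattern (illustrative name, not OneFlow's macro):

```cpp
#include <algorithm>

// Portable min: ternary on the device (std::min is not generally callable
// from GPU code), standard library on the host.
template<typename T>
__host__ __device__ inline T DeviceMinSketch(T a, T b) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  return a < b ? a : b;  // compiled by the nvcc/hipcc device pass
#else
  return std::min(a, b);  // compiled by the host pass
#endif
}
```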
...
@@ -54,7 +54,7 @@ struct CastCASImpl {
   }
 };
-#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
+#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)
 template<typename T>
 struct CastCASImpl<T, unsigned short int> {
...
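The widened guard pulls the `CastCASImpl<T, unsigned short int>` specialization in for ROCm as well. On a hedged reading, the point is 16-bit CAS emulation: where no native 16-bit `atomicCAS` exists, a half-word compare-and-swap is built from a 32-bit `atomicCAS` on the enclosing aligned word, roughly like this illustrative helper (assumed name; assumes `addr` lies inside a naturally aligned 32-bit word):

```cpp
#include <cstdint>

__device__ unsigned short AtomicCas16Sketch(unsigned short* addr,
                                            unsigned short compare,
                                            unsigned short val) {
  // Locate the aligned 32-bit word containing the half-word.
  auto* base = reinterpret_cast<unsigned int*>(
      reinterpret_cast<uintptr_t>(addr) & ~uintptr_t{3});
  const bool hi = (reinterpret_cast<uintptr_t>(addr) & 2u) != 0;
  unsigned int old = *base;
  unsigned int assumed = 0;
  do {
    assumed = old;
    const unsigned short cur = hi ? static_cast<unsigned short>(assumed >> 16)
                                  : static_cast<unsigned short>(assumed & 0xffffu);
    if (cur != compare) { return cur; }  // value changed under us: report it
    const unsigned int desired =
        hi ? ((assumed & 0x0000ffffu) | (static_cast<unsigned int>(val) << 16))
           : ((assumed & 0xffff0000u) | val);
    old = atomicCAS(base, assumed, desired);  // CAS the whole 32-bit word
  } while (old != assumed);
  return compare;  // success: the previous half-word equaled `compare`
}
```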
...
@@ -308,7 +308,8 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
     count_shared[wid] = warp_count;
   }
   __syncthreads();
-  if (wid == 0) {
+#ifdef WITH_ROCM
   if (threadIdx.x < blockDim.x / kWarpSize) {
     warp_mean = mean_shared[lid];
     warp_m2 = m2_shared[lid];
...
@@ -318,10 +319,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
     warp_m2 = static_cast<T>(0);
     warp_count = static_cast<T>(0);
   }
-#ifdef WITH_ROCM
   __syncthreads();
+  if (wid == 0) {
 #else
-  __syncwarp();
+  if (wid == 0) {
+    if (threadIdx.x < blockDim.x / kWarpSize) {
+      warp_mean = mean_shared[lid];
+      warp_m2 = m2_shared[lid];
+      warp_count = count_shared[lid];
+    } else {
+      warp_mean = static_cast<T>(0);
+      warp_m2 = static_cast<T>(0);
+      warp_count = static_cast<T>(0);
+    }
+    __syncwarp();
 #endif
   T block_mean = 0;
   T block_m2 = 0;
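The restructuring above is about barrier legality, on a hedged reading: `__syncthreads()` must be reached by every thread of the block, so it cannot sit inside the `if (wid == 0)` branch that the CUDA path uses, and HIP has no `__syncwarp()`. The ROCm path therefore loads the per-warp partials and synchronizes unconditionally before restricting to warp 0. A self-contained sketch of the two shapes (assumed names, simplified to a plain sum):

```cpp
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif

// Sketch of the two block-reduce shapes. On ROCm, every thread loads a
// partial and hits the block-wide barrier before warp 0 reduces; on CUDA the
// loads and a warp-only barrier stay inside the warp-0 branch.
__global__ void SumWarpPartialsSketch(const float* partials, float* out) {
  const int wid = threadIdx.x / warpSize;
  const int num_warps = blockDim.x / warpSize;
  float v = 0.f;
#ifdef __HIP_PLATFORM_AMD__
  if (threadIdx.x < num_warps) { v = partials[threadIdx.x]; }
  __syncthreads();  // reached by all threads of the block: legal
  if (wid == 0) {
#else
  if (wid == 0) {
    if (threadIdx.x < num_warps) { v = partials[threadIdx.x]; }
    __syncwarp();  // only warp 0 participates: a warp barrier suffices
#endif
    for (int off = warpSize / 2; off > 0; off /= 2) {
#ifdef __HIP_PLATFORM_AMD__
      v += __shfl_down(v, off);  // 64-lane wavefront shuffle
#else
      v += __shfl_down_sync(0xffffffffu, v, off);
#endif
    }
    if (threadIdx.x == 0) { *out = v; }
  }
}
```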
...
@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
                                             const double epsilon, ComputeType* mean,
                                             ComputeType* inv_variance) {
   constexpr int block_size = 128;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   static_assert(block_size % thread_group_width == 0, "");
   constexpr int thread_groups_per_block = block_size / thread_group_width;
   dim3 block_dim(thread_group_width, thread_groups_per_block);
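The ROCm-only bump of `waves` from 32 to 64 recurs below (the block-uncached and grad launchers, and `kNumWaves` in the group-norm kernel). `waves` caps how many rounds of resident blocks the launcher requests per multiprocessor; raising it allows a larger grid on AMD's wider compute units. A hedged sketch of the sizing helper such a constant usually feeds (assumed name; the CUDA runtime calls shown here do exist with these signatures):

```cpp
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Request at most `waves` waves of blocks per device, never more blocks
// than `max_blocks`, and always at least one block.
inline cudaError_t GetNumBlocksSketch(int64_t block_size, int64_t max_blocks,
                                      int64_t waves, int* num_blocks) {
  int dev = 0;
  cudaError_t err = cudaGetDevice(&dev);
  if (err != cudaSuccess) { return err; }
  int sm_count = 0;
  err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  if (err != cudaSuccess) { return err; }
  int tpm = 0;
  err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
  if (err != cudaSuccess) { return err; }
  *num_blocks = static_cast<int>(std::max<int64_t>(
      1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves)));
  return cudaSuccess;
}
```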
...
@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
         stream, load, store, rows, cols, epsilon, mean, inv_variance); \
   } \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col) \
   else if (cols <= (max_col)*kWarpSize) { \
...
@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
         stream, load, store, rows, cols, epsilon, mean, inv_variance); \
   } \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col) \
   else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
...
@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
                                                      const double epsilon, ComputeType* mean,
                                                      ComputeType* inv_variance) {
   constexpr int block_size = 1024;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   int grid_dim_x;
   {
     GPU(Error_t) err =
...
@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
                                                 const ComputeType* inv_variance, const int64_t rows,
                                                 const int64_t cols) {
   constexpr int block_size = 128;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   static_assert(block_size % thread_group_width == 0, "");
   constexpr int thread_groups_per_block = block_size / thread_group_width;
   dim3 block_dim(thread_group_width, thread_groups_per_block);
...
@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
         stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
   } \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col) \
   else if (cols <= (max_col)*kWarpSize) { \
...
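The `#ifdef WITH_ROCM` around these `DEFINE_ONE_ELIF` chains collapses the sub-warp dispatch: on CUDA, small column counts are handled by thread groups of 4/8/16/32 lanes, while on ROCm `kWarpSize` is 64 and only the full wavefront width is instantiated. Reduced to ordinary code, the dispatch idea looks roughly like this (assumed name, simplified from the macros):

```cpp
#include <cstdint>
#include <initializer_list>

// Pick the smallest thread-group width that covers `cols` elements when each
// thread loads `pack_size` values. On ROCm the only width is the wavefront.
inline int PickThreadGroupWidth(int64_t cols, int pack_size) {
#ifdef WITH_ROCM
  (void)cols;
  (void)pack_size;
  return 64;  // full 64-lane wavefront only
#else
  for (int width : {4, 8, 16, 32}) {
    if (cols <= static_cast<int64_t>(width) * pack_size) { return width; }
  }
  return 32;  // fall back to a full warp
#endif
}
```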
...
@@ -252,15 +252,21 @@ struct SetContext {
     const int lane_age = ages[thread_ctx.lane_id];
     const bool lane_hit = (lane_key == key && lane_age != 0);
 #ifdef WITH_ROCM
-    const unsigned hit_mask = __ballot(lane_hit);
+    const unsigned long long int hit_mask = __ballot(lane_hit);
+    if (hit_mask != 0) {
+      return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
+    } else {
+      return -1;
+    }
 #else
     const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
-#endif
     if (hit_mask != 0) {
       return __ffs(static_cast<int>(hit_mask)) - 1;
     } else {
       return -1;
     }
+#endif
   }
   __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
...
@@ -277,15 +283,16 @@
     const Key lane_key = keys[thread_ctx.lane_id];
     int lane_age = ages[thread_ctx.lane_id];
 #ifdef WITH_ROCM
-    const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
+    const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
 #else
     const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
 #endif
     if (hit_mask != 0) {
-      insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
 #ifdef WITH_ROCM
+      insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
       const int insert_way_age = __shfl(lane_age, insert_way);
 #else
+      insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
       const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
 #endif
       if (lane_age > insert_way_age) {
...
@@ -301,12 +308,16 @@ __syncwarp();
       }
     }
     if (insert_way == -1) {
 #ifdef WITH_ROCM
-      const unsigned valid_mask = __ballot(lane_age != 0);
+      const unsigned long long int valid_mask = __ballot(lane_age != 0);
 #else
       const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
 #endif
       if (valid_mask != kFullMask) {
+#ifdef WITH_ROCM
+        insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
+#else
         insert_way = __popc(static_cast<int>(valid_mask));
+#endif
         if (lane_age > 0) {
           lane_age -= 1;
         } else if (thread_ctx.lane_id == insert_way) {
...
@@ -329,7 +340,8 @@ __syncwarp();
     const Key lane_key = keys[thread_ctx.lane_id];
     int lane_age = ages[thread_ctx.lane_id];
 #ifdef WITH_ROCM
-    const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
+    unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
+    const int insert_way = __ffsll(valid_mask_tmp) - 1;
 #else
     const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
 #endif
...
@@ -594,7 +606,8 @@ __syncwarp();
     warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
     warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
 #ifdef WITH_ROCM
-    const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
+    unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
+    const int key_count = __popcll(valid_mask_tmp);
 #else
     const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
 #endif
...
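The common thread in this file: on AMD GPUs a wavefront has 64 lanes, so `__ballot()` returns a 64-bit mask and the bit-scan/popcount calls must be the `ll` variants (`__ffsll`, `__popcll`); truncating the mask through `static_cast<int>` would silently drop lanes 32 through 63. A minimal sketch of the corrected pattern (assumed function name; `0xffffffffu` stands in for the file's `kFullMask` constant):

```cpp
// Return the index of the first lane whose predicate is true, or -1.
__device__ int FirstHitLaneSketch(bool lane_hit) {
#ifdef WITH_ROCM
  const unsigned long long int mask = __ballot(lane_hit);  // 64-bit on AMD
  return mask != 0 ? __ffsll(mask) - 1 : -1;
#else
  const unsigned mask = __ballot_sync(0xffffffffu, lane_hit);  // 32-bit
  return mask != 0 ? __ffs(static_cast<int>(mask)) - 1 : -1;
#endif
}
```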
...
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
+#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
 #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
 #include "oneflow/core/ep/cuda/primitive/type_seq.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
...
...
@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
   OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
   OF_DEVICE_FUNC half operator()(half src) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const float tanh_in =
         __half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
     const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);
...
@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* dst, const half* src) const {
     const half2 src2 = *(reinterpret_cast<const half2*>(src));
     const float2 tanh_in = __half22float2(__hmul2(
...
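Context for the widened guard, hedged: the fast half GeLU path funnels through a float tanh approximation that on CUDA uses the PTX `tanh.approx.f32` instruction, which is what the sm_75 / CUDA 11 requirement protects; ROCm has no PTX, so the same path can be enabled unconditionally with an ordinary `tanhf` underneath. A sketch of such a helper (assumed name; the inline asm is the well-known CUDA idiom, not necessarily OneFlow's exact code):

```cpp
#include <cmath>

__device__ __forceinline__ float TanhApproxSketch(float x) {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) && !defined(WITH_ROCM)
  float r;
  asm("tanh.approx.f32 %0,%1;" : "=f"(r) : "f"(x));  // fast HW approximation
  return r;
#else
  return tanhf(x);  // ROCm and older CUDA archs: plain device tanh
#endif
}
```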
...
@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
       "reduce_sum_like",
       "layer_norm_grad",
       "layer_norm",
-      "layer_norm_param_grad"
-  };
-  return black_list;
-}
-
-const AMPList& AutoMixedPrecisionLists::GrayList() {
-  static AMPList gray_list = {"add_n",
+      "layer_norm_param_grad",
+      "add_n",
       "tf_avg_pool_1d",
       "tf_avg_pool_1d_grad",
       "tf_avg_pool_2d",
       "tf_avg_pool_2d_grad",
       "tf_avg_pool_3d",
       "tf_avg_pool_3d_grad",
-      "avg_pool_1d",
-      "avg_pool_1d_grad",
-      "avg_pool_2d",
-      "avg_pool_2d_grad",
-      "avg_pool_3d",
-      "avg_pool_3d_grad",
+      // "avg_pool_1d",
+      // "avg_pool_1d_grad",
+      // "avg_pool_2d",
+      // "avg_pool_2d_grad",
+      // "avg_pool_3d",
+      // "avg_pool_3d_grad",
       "bias_add",
       "sigmoid_grad",
       "tanh",
...
@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
       "group_norm_grad",
       "silu",
       "silu_grad",
-      "fused_weighted_sum"};
-  return gray_list;
-}
-
-const AMPList& AutoMixedPrecisionLists::ClearList() {
-  // TODO(niuchong): tuple_identity
-  static AMPList clear_list = {"broadcast_like",
+      "fused_weighted_sum",
+      "broadcast_like",
       "gather",
       "gather_nd",
       "scatter_nd",
...
@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
       "tf_max_pool_2d_grad",
       "tf_max_pool_3d",
       "tf_max_pool_3d_grad",
-      "max_pool_1d",
-      "max_pool_1d_grad",
-      "max_pool_2d",
-      "max_pool_2d_grad",
-      "max_pool_3d",
-      "max_pool_3d_grad",
+      // "max_pool_1d",
+      // "max_pool_1d_grad",
+      // "max_pool_2d",
+      // "max_pool_2d_grad",
+      // "max_pool_3d",
+      // "max_pool_3d_grad",
       "reshape",
       "reshape_like",
       "relu",
...
@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
       "pinned_identity",
       "to_contiguous",
       "copy",
-      "upsample_nearest_2d"};
+      "upsample_nearest_2d"
+  };
+  return black_list;
+}
+
+const AMPList& AutoMixedPrecisionLists::GrayList() {
+  static AMPList gray_list = {
+      // "add_n",
+      // "tf_avg_pool_1d",
+      // "tf_avg_pool_1d_grad",
+      // "tf_avg_pool_2d",
+      // "tf_avg_pool_2d_grad",
+      // "tf_avg_pool_3d",
+      // "tf_avg_pool_3d_grad",
+      "avg_pool_1d",
+      "avg_pool_1d_grad",
+      "avg_pool_2d",
+      "avg_pool_2d_grad",
+      "avg_pool_3d",
+      "avg_pool_3d_grad"
+      // "bias_add",
+      // "sigmoid_grad",
+      // "tanh",
+      // "tanh_grad",
+      // "sqrt",
+      // "sqrt_grad",
+      // "scalar_mul",
+      // "scalar_mul_by_tensor",
+      // "scalar_add",
+      // "scalar_div",
+      // "scalar_pow",
+      // "broadcast_add",
+      // "broadcast_sub",
+      // "broadcast_mul",
+      // "broadcast_div",
+      // "rms_norm",
+      // "rms_norm_grad",
+      // "rms_norm_param_grad",
+      // "dropout",
+      // "dropout_grad",
+      // "softmax",
+      // "softmax_grad",
+      // "log_softmax",
+      // "log_softmax_grad",
+      // "gelu",
+      // "gelu_grad",
+      // "fast_gelu",
+      // "fast_gelu_grad",
+      // "normalization",
+      // "normalization_grad",
+      // "normalization_add_relu",
+      // "normalization_add_relu_grad",
+      // "sparse_softmax_cross_entropy",
+      // "sparse_softmax_cross_entropy_grad",
+      // "nll",
+      // "nll_grad",
+      // "fused_tril_scale_softmax_mask_scale",
+      // "fused_tril_scale_softmax_mask_scale_grad",
+      // "fused_scale_mask_softmax_dropout",
+      // "fused_scale_mask_softmax_dropout_grad",
+      // "fused_scale_mask_softmax",
+      // "fused_scale_mask_softmax_grad",
+      // "fused_bias_add_scale_mask_softmax_dropout",
+      // "fused_bias_add_gelu",
+      // "fused_bias_add_gelu_grad",
+      // "fused_bias_add_mask_scale",
+      // "fused_fast_gelu_mul",
+      // "fused_fast_gelu_mul_grad",
+      // "acc",
+      // "reciprocal",
+      // "reciprocal_no_nan",
+      // "group_norm",
+      // "group_norm_param_grad",
+      // "group_norm_grad",
+      // "silu",
+      // "silu_grad",
+      // "fused_weighted_sum"
+  };
+  return gray_list;
+}
+
+const AMPList& AutoMixedPrecisionLists::ClearList() {
+  // TODO(niuchong): tuple_identity
+  static AMPList clear_list = {
+      // "broadcast_like",
+      // "gather",
+      // "gather_nd",
+      // "scatter_nd",
+      // "scatter_nd_like",
+      // "unsorted_segment_sum_like",
+      // "tf_max_pool_1d",
+      // "tf_max_pool_1d_grad",
+      // "tf_max_pool_2d",
+      // "tf_max_pool_2d_grad",
+      // "tf_max_pool_3d",
+      // "tf_max_pool_3d_grad",
+      "max_pool_1d",
+      "max_pool_1d_grad",
+      "max_pool_2d",
+      "max_pool_2d_grad",
+      "max_pool_3d",
+      "max_pool_3d_grad"
+      // "reshape",
+      // "reshape_like",
+      // "relu",
+      // "relu_grad",
+      // "transpose",
+      // "random_mask_like",
+      // "concat",
+      // "split_like",
+      // "pad",
+      // "same_padding",
+      // "same_padding_grad",
+      // "tril",
+      // "slice",
+      // "slice_grad",
+      // "fused_scale_tril",
+      // "identity",
+      // "squeeze",
+      // "embedding",
+      // "embedding_grad",
+      // "expand",
+      // "expand_dims",
+      // "cast_to_static_shape",
+      // "parallel_cast",
+      // "hierarchical_parallel_cast",
+      // "hierarchical_parallel_cast_like",
+      // "repeat",
+      // "unpack",
+      // "pack",
+      // "nvtx_start",
+      // "nvtx_end",
+      // "narrow",
+      // "narrow_grad",
+      // "ones_like",
+      // "pinned_identity",
+      // "to_contiguous",
+      // "copy",
+      // "upsample_nearest_2d"
+  };
   return clear_list;
 }
...
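For orientation, a hedged summary of what this file controls: the black list pins ops to float32, the gray list lets an op's precision follow its inputs, and the clear list marks precision-neutral ops; this commit effectively freezes most former gray- and clear-list entries into the black list for the DTK/ROCm build, leaving only the pooling ops active in gray/clear. A toy classifier showing how such lists are typically consulted (assumed API, not OneFlow's actual AMP pass):

```cpp
#include <set>
#include <string>

enum class AmpColor { kBlack, kGray, kClear, kWhite };

inline AmpColor ClassifyOp(const std::string& op_type,
                           const std::set<std::string>& black,
                           const std::set<std::string>& gray,
                           const std::set<std::string>& clear) {
  if (black.count(op_type)) { return AmpColor::kBlack; }  // keep float32
  if (gray.count(op_type)) { return AmpColor::kGray; }    // context-dependent
  if (clear.count(op_type)) { return AmpColor::kClear; }  // precision-neutral
  return AmpColor::kWhite;                                // prefer half
}
```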
...
@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
   OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}
   OF_DEVICE_FUNC half operator()(const half x, const half m) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const float tanh_in =
         __half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
     const float tanh_out = TanhApprox(tanh_in);
...
@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* y, const half* x, const half* m) const {
     const half2 x2 = *(reinterpret_cast<const half2*>(x));
     const float2 tanh_in = __half22float2(
...
@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
   __device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
                              const half& m) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const half halpha = __float2half_rn(alpha);
     const half hbeta = __float2half_rn(beta);
     const half hone = __float2half_rn(1.0F);
...
@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
                          const half* m) const {
     const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
...
...
@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
 constexpr int kReduceBlockSize = 512;
 constexpr int kBlockSize = 128;
+#ifdef WITH_ROCM
+constexpr int kNumWaves = 64;
+#else
 constexpr int kNumWaves = 32;
+#endif
 inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
   int dev;
...
...
@@ -28,7 +28,7 @@ limitations under the License.
 #elif defined(__HIPCC__)
-#include <hip/hsa_detail/math_functions.h>
+#include <hip/math_functions.h>
 #include <hip/hip_fp16.h>
 #if defined(__HIP_DEVICE_COMPILE__)
...
...
@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
       CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
       reserve_space_bits = reserve_space_bits / split_num;
     }
+#ifdef WITH_ROCM
+    reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
+#else
     reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }
...
@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
   return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                     user_op::TensorDesc* reserve_space) -> Maybe<void> {
     const auto& x_desc = ctx->InputTensorDesc("x", 0);
+#ifdef WITH_ROCM
+    reserve_space->set_shape(
+        Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
+#else
     reserve_space->set_shape(
         Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }
...
@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
 /* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
   return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                   user_op::TensorDesc* reserve_space) -> Maybe<void> {
+#ifdef WITH_ROCM
+    reserve_space->set_data_type(DataType::kInt64);
+#else
     reserve_space->set_data_type(DataType::kInt32);
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }
...
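The reserve-space change, on a hedged reading: `normalization_add_relu` keeps an activation bitmask with one bit per element; the CUDA build packs it into 32-bit words (hence `kInt32` and rounding by 32), while the ROCm build packs into 64-bit words (`kInt64`, rounding by 64). A sketch of the sizing arithmetic:

```cpp
#include <cstdint>

// Number of mask words needed for `elem_cnt` one-bit flags, matching
// Shape({RoundUp(elem_cnt, kBits) / kBits}) in the inference functions above.
inline int64_t ReserveSpaceElems(int64_t elem_cnt) {
#ifdef WITH_ROCM
  constexpr int64_t kBits = 64;  // one int64 holds 64 mask bits
#else
  constexpr int64_t kBits = 32;  // one int32 holds 32 mask bits
#endif
  return (elem_cnt + kBits - 1) / kBits;
}
```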