Commit 6046d8fb authored by yuguo960516yuguo

dtk23.04

parent a715222c
......@@ -334,16 +334,16 @@ elseif(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
endif()
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS "-O0")
endif()
# if (BUILD_ROCM)
# # AMD compiler fails to compile these three files with '-O1/2/3'.
# # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# # so '-O0' will override '-O1/2/3'.
# set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
# # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
# PROPERTIES COMPILE_OPTIONS "-O0")
# endif()
if(BUILD_CUDA)
string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
......
......@@ -186,7 +186,7 @@ if (BUILD_ROCM)
if (BUILD_ROCM_GRAPHS)
add_definitions(-DWITH_ROCM_GRAPHS)
endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --gpu-max-threads-per-block=1024")
......@@ -204,7 +204,7 @@ if (BUILD_ROCM)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large")
list(APPEND oneflow_third_party_libs hip::device)
list(APPEND oneflow_third_party_libs hip::hipfft)
list(APPEND oneflow_third_party_libs roc::hipblas)
......
......@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template<typename T>
OF_DEVICE_FUNC T DeviceMin(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a < b ? a : b;
#else
return std::min(a, b);
......@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template<typename T>
OF_DEVICE_FUNC T DeviceMax(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a > b ? a : b;
#else
return std::max(a, b);
......
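Why the extra macro test: hipcc never defines __CUDA_ARCH__, even in its device compilation pass (it defines __HIP_DEVICE_COMPILE__ instead), so the old guard sent HIP device code to std::min/std::max, which are host-only. A minimal self-contained sketch of the corrected dispatch (the name DeviceMinSketch is illustrative):

#include <algorithm>

// nvcc defines __CUDA_ARCH__ in device passes; hipcc defines
// __HIP_DEVICE_COMPILE__ instead, so both must be tested.
template<typename T>
__host__ __device__ T DeviceMinSketch(T a, T b) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  return a < b ? a : b;   // device path: std:: algorithms unavailable
#else
  return std::min(a, b);  // host path
#endif
}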
......@@ -54,7 +54,7 @@ struct CastCASImpl {
}
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)
template<typename T>
struct CastCASImpl<T, unsigned short int> {
......
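Adding WITH_ROCM here routes ROCm builds through the unsigned short specialization, since native 16-bit atomicCAS is a CUDA sm_70+ feature that HIP does not guarantee. The usual trick behind such a specialization is to CAS the enclosing aligned 32-bit word; a hedged sketch (helper name illustrative, not the file's actual code):

#include <cstdint>

__device__ unsigned short AtomicCasU16Sketch(unsigned short* addr, unsigned short compare,
                                             unsigned short val) {
  // Aligned 32-bit word containing *addr; bit 1 of the address picks the half.
  auto* base = reinterpret_cast<unsigned int*>(reinterpret_cast<uintptr_t>(addr) & ~uintptr_t{2});
  const bool hi = (reinterpret_cast<uintptr_t>(addr) & 2) != 0;
  unsigned int old = *base;
  unsigned int assumed = 0;
  do {
    assumed = old;
    const unsigned short cur = hi ? static_cast<unsigned short>(assumed >> 16)
                                  : static_cast<unsigned short>(assumed & 0xffffu);
    if (cur != compare) { return cur; }  // observed mismatch: report current half
    const unsigned int desired =
        hi ? ((assumed & 0x0000ffffu) | (static_cast<unsigned int>(val) << 16))
           : ((assumed & 0xffff0000u) | val);
    old = atomicCAS(base, assumed, desired);
  } while (old != assumed);
  return compare;  // swap succeeded; old half equaled `compare`
}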
......@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared[wid] = warp_count;
}
__syncthreads();
#ifdef WITH_ROCM
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncthreads();
if (wid == 0) {
#else
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
......@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
__syncwarp();
#endif
T block_mean = 0;
T block_m2 = 0;
......@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
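For context, `waves` caps how many full-occupancy waves of blocks the launcher puts in flight: the grid is clamped to roughly sm_count * max_threads_per_sm / block_size * waves blocks, and the ROCm port doubles that cap. A hedged sketch of such a helper, using CUDA spellings (HIP builds would use the hip* equivalents; GetNumBlocksSketch is an illustrative name):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

inline cudaError_t GetNumBlocksSketch(int64_t block_size, int64_t max_blocks, int64_t waves,
                                      int* num_blocks) {
  int dev = 0;
  cudaError_t err = cudaGetDevice(&dev);
  if (err != cudaSuccess) { return err; }
  int sm_count = 0;
  err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  if (err != cudaSuccess) { return err; }
  int tpm = 0;  // max resident threads per SM (compute unit on AMD)
  err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
  if (err != cudaSuccess) { return err; }
  // Never launch more than `waves` full-occupancy waves of blocks.
  *num_blocks =
      std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
  return cudaSuccess;
}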
......@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
......@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
int grid_dim_x;
{
GPU(Error_t) err =
......@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
......@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......
......@@ -252,15 +252,21 @@ struct SetContext {
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_hit);
const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
} else {
return -1;
}
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
#endif
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
return -1;
}
#endif
}
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
......@@ -277,15 +283,16 @@ struct SetContext {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
#endif
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
#ifdef WITH_ROCM
insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way);
#else
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
#endif
if (lane_age > insert_way_age) {
......@@ -301,12 +308,16 @@ __syncwarp();
}
if (insert_way == -1) {
#ifdef WITH_ROCM
const unsigned valid_mask = __ballot(lane_age != 0);
const unsigned long long int valid_mask = __ballot(lane_age != 0);
#else
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#endif
if (valid_mask != kFullMask) {
#ifdef WITH_ROCM
insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
#else
insert_way = __popc(static_cast<int>(valid_mask));
#endif
if (lane_age > 0) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
......@@ -329,7 +340,8 @@ __syncwarp();
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
const int insert_way = __ffsll(valid_mask_tmp) - 1;
#else
const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
#endif
......@@ -594,7 +606,8 @@ __syncwarp();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
#ifdef WITH_ROCM
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
const int key_count = __popcll(valid_mask_tmp);
#else
const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
#endif
......
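Every change in this cache file follows one rule: AMD wavefronts are 64 lanes wide, so HIP's __ballot returns a 64-bit mask that must be scanned and counted with __ffsll/__popcll, whereas the 32-bit result of CUDA's __ballot_sync pairs with __ffs/__popc. The recurring pattern, condensed into one hedged helper (name illustrative):

// Index of the first lane whose predicate holds, or -1 if none.
__device__ int FirstSetLaneSketch(bool pred) {
#ifdef WITH_ROCM
  const unsigned long long int mask = __ballot(pred);      // 64-bit wavefront mask
  return mask != 0 ? __ffsll(mask) - 1 : -1;
#else
  const unsigned mask = __ballot_sync(0xffffffffu, pred);  // 32-bit warp mask
  return mask != 0 ? __ffs(static_cast<int>(mask)) - 1 : -1;
#endif
}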
......@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
} // namespace oneflow
......@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
OF_DEVICE_FUNC half operator()(half src) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);
......@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* dst, const half* src) const {
const half2 src2 = *(reinterpret_cast<const half2*>(src));
const float2 tanh_in = __half22float2(__hmul2(
......
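Extending these guards with WITH_ROCM is safe because the fast branch only needs half/float conversions, half2 arithmetic, and a float tanh, all of which HIP's fp16 support provides; the hardware tanh instruction stays NVIDIA-only inside TanhApprox. A hedged sketch of a TanhApprox compatible with both guards:

__device__ inline float TanhApproxSketch(float v) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000
  float r;
  asm("tanh.approx.f32 %0, %1;" : "=f"(r) : "f"(v));  // Turing+ hardware tanh
  return r;
#else
  return tanhf(v);  // ROCm and pre-Turing NVIDIA fall back to tanhf
#endif
}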
......@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like",
"layer_norm_grad",
"layer_norm",
"layer_norm_param_grad"
};
return black_list;
}
"layer_norm_param_grad",
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {"add_n",
"add_n",
"tf_avg_pool_1d",
"tf_avg_pool_1d_grad",
"tf_avg_pool_2d",
"tf_avg_pool_2d_grad",
"tf_avg_pool_3d",
"tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad",
// "avg_pool_1d",
// "avg_pool_1d_grad",
// "avg_pool_2d",
// "avg_pool_2d_grad",
// "avg_pool_3d",
// "avg_pool_3d_grad",
"bias_add",
"sigmoid_grad",
"tanh",
......@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad",
"silu",
"silu_grad",
"fused_weighted_sum"};
return gray_list;
}
"fused_weighted_sum",
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {"broadcast_like",
"broadcast_like",
"gather",
"gather_nd",
"scatter_nd",
......@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad",
"tf_max_pool_3d",
"tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad",
// "max_pool_1d",
// "max_pool_1d_grad",
// "max_pool_2d",
// "max_pool_2d_grad",
// "max_pool_3d",
// "max_pool_3d_grad",
"reshape",
"reshape_like",
"relu",
......@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity",
"to_contiguous",
"copy",
"upsample_nearest_2d"};
"upsample_nearest_2d"
};
return black_list;
}
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return gray_list;
}
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return clear_list;
}
......
......@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}
OF_DEVICE_FUNC half operator()(const half x, const half m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
const float tanh_out = TanhApprox(tanh_in);
......@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* y, const half* x, const half* m) const {
const half2 x2 = *(reinterpret_cast<const half2*>(x));
const float2 tanh_in = __half22float2(
......@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
const half& m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const half halpha = __float2half_rn(alpha);
const half hbeta = __float2half_rn(beta);
const half hone = __float2half_rn(1.0F);
......@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
const half* m) const {
const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
......
......@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr int kReduceBlockSize = 512;
constexpr int kBlockSize = 128;
#ifdef WITH_ROCM
constexpr int kNumWaves = 64;
#else
constexpr int kNumWaves = 32;
#endif
inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
int dev;
......
......@@ -28,7 +28,7 @@ limitations under the License.
#elif defined(__HIPCC__)
#include <hip/hsa_detail/math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
......
......@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
reserve_space_bits = reserve_space_bits / split_num;
}
#ifdef WITH_ROCM
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
#else
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
const auto& x_desc = ctx->InputTensorDesc("x", 0);
#ifdef WITH_ROCM
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
#else
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
#ifdef WITH_ROCM
reserve_space->set_data_type(DataType::kInt64);
#else
reserve_space->set_data_type(DataType::kInt32);
#endif
return Maybe<void>::Ok();
})(ctx);
}
......
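reserve_space holds one mask bit per output element; the ROCm branches pack those bits into 64-bit words so each 64-lane wavefront ballot lands in a single word (matching the kInt64 element type set above), versus CUDA's 32-bit packing. The sizing rule as a standalone sketch (RoundUpSketch stands in for the real RoundUp helper):

#include <cstdint>

inline int64_t RoundUpSketch(int64_t n, int64_t align) { return (n + align - 1) / align * align; }

// Number of mask words needed for elem_cnt one-bit flags.
inline int64_t ReserveSpaceWords(int64_t elem_cnt) {
#ifdef WITH_ROCM
  return RoundUpSketch(elem_cnt, 64) / 64;  // int64 words: 64-lane wavefront ballots
#else
  return RoundUpSketch(elem_cnt, 32) / 32;  // int32 words: 32-lane warp ballots
#endif
}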