Commit 6046d8fb authored by yuguo960516yuguo

dtk23.04

parent a715222c
...@@ -334,16 +334,16 @@ elseif(WIN32) ...@@ -334,16 +334,16 @@ elseif(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
endif() endif()
if (BUILD_ROCM) # if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'. # # AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>, # # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'. # # so '-O0' will override '-O1/2/3'.
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp # set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp # ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp # ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp # # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS "-O0") # PROPERTIES COMPILE_OPTIONS "-O0")
endif() # endif()
if(BUILD_CUDA) if(BUILD_CUDA)
string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST}) string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
......
...@@ -186,7 +186,7 @@ if (BUILD_ROCM) ...@@ -186,7 +186,7 @@ if (BUILD_ROCM)
if (BUILD_ROCM_GRAPHS) if (BUILD_ROCM_GRAPHS)
add_definitions(-DWITH_ROCM_GRAPHS) add_definitions(-DWITH_ROCM_GRAPHS)
endif() endif()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ -D__HIPCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --gpu-max-threads-per-block=1024") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --gpu-max-threads-per-block=1024")
...@@ -204,7 +204,7 @@ if (BUILD_ROCM) ...@@ -204,7 +204,7 @@ if (BUILD_ROCM)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large")
list(APPEND oneflow_third_party_libs hip::device) list(APPEND oneflow_third_party_libs hip::device)
list(APPEND oneflow_third_party_libs hip::hipfft) list(APPEND oneflow_third_party_libs hip::hipfft)
list(APPEND oneflow_third_party_libs roc::hipblas) list(APPEND oneflow_third_party_libs roc::hipblas)
......
...@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n); ...@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template<typename T> template<typename T>
OF_DEVICE_FUNC T DeviceMin(T a, T b) { OF_DEVICE_FUNC T DeviceMin(T a, T b) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a < b ? a : b; return a < b ? a : b;
#else #else
return std::min(a, b); return std::min(a, b);
...@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) { ...@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template<typename T> template<typename T>
OF_DEVICE_FUNC T DeviceMax(T a, T b) { OF_DEVICE_FUNC T DeviceMax(T a, T b) {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a > b ? a : b; return a > b ? a : b;
#else #else
return std::max(a, b); return std::max(a, b);
......
...@@ -54,7 +54,7 @@ struct CastCASImpl { ...@@ -54,7 +54,7 @@ struct CastCASImpl {
} }
}; };
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) #if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)
template<typename T> template<typename T>
struct CastCASImpl<T, unsigned short int> { struct CastCASImpl<T, unsigned short int> {
......
...@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t ...@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared[wid] = warp_count; count_shared[wid] = warp_count;
} }
__syncthreads(); __syncthreads();
#ifdef WITH_ROCM
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncthreads();
if (wid == 0) {
#else
if (wid == 0) { if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) { if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid]; warp_mean = mean_shared[lid];
...@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t ...@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2 = static_cast<T>(0); warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0); warp_count = static_cast<T>(0);
} }
#ifdef WITH_ROCM __syncwarp();
__syncthreads();
#else
__syncwarp();
#endif #endif
T block_mean = 0; T block_mean = 0;
T block_m2 = 0; T block_m2 = 0;
...@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO ...@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const double epsilon, ComputeType* mean, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) { ComputeType* inv_variance) {
constexpr int block_size = 128; constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32; constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, ""); static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width; constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block); dim3 block_dim(thread_group_width, thread_groups_per_block);
...@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar ...@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \ } \
} }
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32) DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF #undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \ #define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \ else if (cols <= (max_col)*kWarpSize) { \
...@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar ...@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \ stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \ } \
} }
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32) DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF #undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \ #define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \ else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
...@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD ...@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const double epsilon, ComputeType* mean, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) { ComputeType* inv_variance) {
constexpr int block_size = 1024; constexpr int block_size = 1024;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32; constexpr int waves = 32;
#endif
int grid_dim_x; int grid_dim_x;
{ {
GPU(Error_t) err = GPU(Error_t) err =
...@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa ...@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const ComputeType* inv_variance, const int64_t rows, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) { const int64_t cols) {
constexpr int block_size = 128; constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32; constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, ""); static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width; constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block); dim3 block_dim(thread_group_width, thread_groups_per_block);
...@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra ...@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \ stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \ } \
} }
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16) DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32) DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF #undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \ #define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \ else if (cols <= (max_col)*kWarpSize) { \
......
...@@ -252,15 +252,21 @@ struct SetContext { ...@@ -252,15 +252,21 @@ struct SetContext {
const int lane_age = ages[thread_ctx.lane_id]; const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0); const bool lane_hit = (lane_key == key && lane_age != 0);
#ifdef WITH_ROCM #ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_hit); const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
} else {
return -1;
}
#else #else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit); const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
#endif
if (hit_mask != 0) { if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1; return __ffs(static_cast<int>(hit_mask)) - 1;
} else { } else {
return -1; return -1;
} }
#endif
} }
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx, __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
...@@ -277,15 +283,16 @@ struct SetContext { ...@@ -277,15 +283,16 @@ struct SetContext {
const Key lane_key = keys[thread_ctx.lane_id]; const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id]; int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM #ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0); const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
#else #else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0); const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
#endif #endif
if (hit_mask != 0) { if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
#ifdef WITH_ROCM #ifdef WITH_ROCM
insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way); const int insert_way_age = __shfl(lane_age, insert_way);
#else #else
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way); const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
#endif #endif
if (lane_age > insert_way_age) { if (lane_age > insert_way_age) {
...@@ -301,12 +308,16 @@ __syncwarp(); ...@@ -301,12 +308,16 @@ __syncwarp();
} }
if (insert_way == -1) { if (insert_way == -1) {
#ifdef WITH_ROCM #ifdef WITH_ROCM
const unsigned valid_mask = __ballot(lane_age != 0); const unsigned long long int valid_mask = __ballot(lane_age != 0);
#else #else
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0); const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#endif #endif
if (valid_mask != kFullMask) { if (valid_mask != kFullMask) {
#ifdef WITH_ROCM
insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
#else
insert_way = __popc(static_cast<int>(valid_mask)); insert_way = __popc(static_cast<int>(valid_mask));
#endif
if (lane_age > 0) { if (lane_age > 0) {
lane_age -= 1; lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) { } else if (thread_ctx.lane_id == insert_way) {
...@@ -329,7 +340,8 @@ __syncwarp(); ...@@ -329,7 +340,8 @@ __syncwarp();
const Key lane_key = keys[thread_ctx.lane_id]; const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id]; int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM #ifdef WITH_ROCM
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1; unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
const int insert_way = __ffsll(valid_mask_tmp) - 1;
#else #else
const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1; const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
#endif #endif
...@@ -594,7 +606,8 @@ __syncwarp(); ...@@ -594,7 +606,8 @@ __syncwarp();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key; warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age; warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
#ifdef WITH_ROCM #ifdef WITH_ROCM
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0))); unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
const int key_count = __popcll(valid_mask_tmp);
#else #else
const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0)); const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
#endif #endif
......
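The cache-lookup changes above all follow one pattern: on ROCm the wavefront is 64 lanes wide, so __ballot returns a 64-bit mask and the 64-bit intrinsics __ffsll/__popcll have to be used, whereas CUDA's __ballot_sync returns a 32-bit mask. A minimal sketch of a helper that hides the difference (not part of the patch; the name and the hard-coded full mask are illustrative):

__device__ __forceinline__ unsigned long long FullWarpBallot(bool pred) {
#ifdef WITH_ROCM
  // 64-lane wavefront: __ballot already yields a 64-bit mask.
  return __ballot(pred);
#else
  // 32-lane warp: widen the 32-bit mask so callers can always use
  // the 64-bit bit-scan/popcount intrinsics.
  return static_cast<unsigned long long>(__ballot_sync(0xffffffffu, pred));
#endif
}

// Usage: const unsigned long long mask = FullWarpBallot(lane_hit);
//        const int first_hit = (mask != 0) ? __ffsll(mask) - 1 : -1;
//        const int hit_count = __popcll(mask);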
...@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h" #include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/cuda_stream.h"
......
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" #include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace oneflow { namespace oneflow {
namespace ep { namespace ep {
namespace primitive { namespace primitive {
namespace broadcast_elementwise_binary { namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ #define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \ template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \ binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1); Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_ALL_TYPE_SEQ); BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
} // namespace broadcast_elementwise_binary } // namespace broadcast_elementwise_binary
} // namespace primitive } // namespace primitive
} // namespace ep } // namespace ep
} // namespace oneflow } // namespace oneflow
...@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> { ...@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
OF_DEVICE_FUNC half operator()(half src) const { OF_DEVICE_FUNC half operator()(half src) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in = const float tanh_in =
__half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src)); __half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
const float tanh_out = unary_functor_internal::TanhApprox(tanh_in); const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);
...@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> { ...@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
} }
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* dst, const half* src) const { __device__ void Apply2(half* dst, const half* src) const {
const half2 src2 = *(reinterpret_cast<const half2*>(src)); const half2 src2 = *(reinterpret_cast<const half2*>(src));
const float2 tanh_in = __half22float2(__hmul2( const float2 tanh_in = __half22float2(__hmul2(
......
...@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() { ...@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like", "reduce_sum_like",
"layer_norm_grad", "layer_norm_grad",
"layer_norm", "layer_norm",
"layer_norm_param_grad" "layer_norm_param_grad",
};
return black_list;
}
const AMPList& AutoMixedPrecisionLists::GrayList() { "add_n",
static AMPList gray_list = {"add_n",
"tf_avg_pool_1d", "tf_avg_pool_1d",
"tf_avg_pool_1d_grad", "tf_avg_pool_1d_grad",
"tf_avg_pool_2d", "tf_avg_pool_2d",
"tf_avg_pool_2d_grad", "tf_avg_pool_2d_grad",
"tf_avg_pool_3d", "tf_avg_pool_3d",
"tf_avg_pool_3d_grad", "tf_avg_pool_3d_grad",
"avg_pool_1d", // "avg_pool_1d",
"avg_pool_1d_grad", // "avg_pool_1d_grad",
"avg_pool_2d", // "avg_pool_2d",
"avg_pool_2d_grad", // "avg_pool_2d_grad",
"avg_pool_3d", // "avg_pool_3d",
"avg_pool_3d_grad", // "avg_pool_3d_grad",
"bias_add", "bias_add",
"sigmoid_grad", "sigmoid_grad",
"tanh", "tanh",
...@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() { ...@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad", "group_norm_grad",
"silu", "silu",
"silu_grad", "silu_grad",
"fused_weighted_sum"}; "fused_weighted_sum",
return gray_list;
}
const AMPList& AutoMixedPrecisionLists::ClearList() { "broadcast_like",
// TODO(niuchong): tuple_identity
static AMPList clear_list = {"broadcast_like",
"gather", "gather",
"gather_nd", "gather_nd",
"scatter_nd", "scatter_nd",
...@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() { ...@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad", "tf_max_pool_2d_grad",
"tf_max_pool_3d", "tf_max_pool_3d",
"tf_max_pool_3d_grad", "tf_max_pool_3d_grad",
"max_pool_1d", // "max_pool_1d",
"max_pool_1d_grad", // "max_pool_1d_grad",
"max_pool_2d", // "max_pool_2d",
"max_pool_2d_grad", // "max_pool_2d_grad",
"max_pool_3d", // "max_pool_3d",
"max_pool_3d_grad", // "max_pool_3d_grad",
"reshape", "reshape",
"reshape_like", "reshape_like",
"relu", "relu",
...@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() { ...@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity", "pinned_identity",
"to_contiguous", "to_contiguous",
"copy", "copy",
"upsample_nearest_2d"}; "upsample_nearest_2d"
};
return black_list;
}
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return gray_list;
}
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return clear_list; return clear_list;
} }
......
...@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> { ...@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC FusedFastGeluMulFunctor() {} OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}
OF_DEVICE_FUNC half operator()(const half x, const half m) const { OF_DEVICE_FUNC half operator()(const half x, const half m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in = const float tanh_in =
__half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x)); __half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
const float tanh_out = TanhApprox(tanh_in); const float tanh_out = TanhApprox(tanh_in);
...@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> { ...@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
} }
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* y, const half* x, const half* m) const { __device__ void Apply2(half* y, const half* x, const half* m) const {
const half2 x2 = *(reinterpret_cast<const half2*>(x)); const half2 x2 = *(reinterpret_cast<const half2*>(x));
const float2 tanh_in = __half22float2( const float2 tanh_in = __half22float2(
...@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> { ...@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x, __device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
const half& m) const { const half& m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const half halpha = __float2half_rn(alpha); const half halpha = __float2half_rn(alpha);
const half hbeta = __float2half_rn(beta); const half hbeta = __float2half_rn(beta);
const half hone = __float2half_rn(1.0F); const half hone = __float2half_rn(1.0F);
...@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> { ...@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
} }
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) #if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x, __device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
const half* m) const { const half* m) const {
const half2 dy2 = *(reinterpret_cast<const half2*>(dy)); const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
......
...@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16) ...@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr int kReduceBlockSize = 512; constexpr int kReduceBlockSize = 512;
constexpr int kBlockSize = 128; constexpr int kBlockSize = 128;
#ifdef WITH_ROCM
constexpr int kNumWaves = 64;
#else
constexpr int kNumWaves = 32; constexpr int kNumWaves = 32;
#endif
inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) { inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
int dev; int dev;
......
...@@ -27,9 +27,12 @@ limitations under the License. ...@@ -27,9 +27,12 @@ limitations under the License.
#include <cuda_bf16.h> #include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000 #endif // CUDA_VERSION >= 11000
#ifdef WITH_CUDA #ifdef WITH_ROCM
#include <thrust/pair.h> #include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh> #include <cub/cub.cuh>
#endif
namespace oneflow { namespace oneflow {
namespace { namespace {
...@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) { ...@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) {
return val; return val;
} }
constexpr int tile_size = 32;
constexpr int num_per_block = 4; constexpr int num_per_block = 4;
#ifdef WITH_ROCM
constexpr int tile_size = 64;
constexpr int block_dim_x = 64;
constexpr int block_dim_y = 64 / num_per_block;
#else
constexpr int tile_size = 32;
constexpr int block_dim_x = 32; constexpr int block_dim_x = 32;
constexpr int block_dim_y = 32 / num_per_block; constexpr int block_dim_y = 32 / num_per_block;
#endif
template<typename T, typename ComputeType> template<typename T, typename ComputeType>
__global__ void LayerNormParamGrad(int rows, int cols, const T* __restrict__ dy, __global__ void LayerNormParamGrad(int rows, int cols, const T* __restrict__ dy,
const T* __restrict__ x, const ComputeType* __restrict__ mean, const T* __restrict__ x, const ComputeType* __restrict__ mean,
const ComputeType* __restrict__ inv_var, const ComputeType* __restrict__ inv_var,
T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) { T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) {
#ifdef WITH_ROCM
__shared__ ComputeType dgamma[64][65];
__shared__ ComputeType dbeta[64][65];
#else
__shared__ ComputeType dgamma[32][33]; __shared__ ComputeType dgamma[32][33];
__shared__ ComputeType dbeta[32][33]; __shared__ ComputeType dbeta[32][33];
#endif
ComputeType dgamma_sum[num_per_block]; ComputeType dgamma_sum[num_per_block];
ComputeType dbeta_sum[num_per_block]; ComputeType dbeta_sum[num_per_block];
#pragma unroll #pragma unroll
...@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) { ...@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) {
const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size; const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size;
const int block_size = block_dim_x * block_dim_y; const int block_size = block_dim_x * block_dim_y;
int max_active_blocks = 0; int max_active_blocks = 0;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( OF_CUDA_CHECK(GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks, LayerNormParamGrad<T, ComputeType>, block_size, 0)); &max_active_blocks, LayerNormParamGrad<T, ComputeType>, block_size, 0));
int waves = 1; int waves = 1;
int dev; int dev;
OF_CUDA_CHECK(cudaGetDevice(&dev)); OF_CUDA_CHECK(GPU(GetDevice)(&dev));
int sm_count; int sm_count;
OF_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev)); OF_CUDA_CHECK(GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev));
int num_blocks = max_active_blocks * sm_count * waves; int num_blocks = max_active_blocks * sm_count * waves;
int grid_dim_y = std::min(max_grid_dim_y, static_cast<int>(num_blocks / grid_dim_x)); int grid_dim_y = std::min(max_grid_dim_y, static_cast<int>(num_blocks / grid_dim_x));
return std::max(grid_dim_y, 1); return std::max(grid_dim_y, 1);
...@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel, ...@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
const int64_t norm_size = x->shape_view().elem_cnt() / num_instances; const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const DataType data_type = dy->data_type(); const DataType data_type = dy->data_type();
const int grid_dim_x = (norm_size + tile_size - 1) / tile_size; const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size); const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T); const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
...@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel, ...@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
T* reduce_buf_ptr = T* reduce_buf_ptr =
reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size); reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type; using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block), LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(block_dim_x, block_dim_y),
0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>( 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr); inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
...@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel, ...@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
}); });
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
#ifdef WITH_CUDA
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif
} // namespace oneflow
#endif #endif
#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#include <thrust/pair.h>
template <typename T, bool is_cuda>
struct AccumulateType { };
#if defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; };
#endif
template <> struct AccumulateType<float, true> { using type = float; };
template <> struct AccumulateType<double, true> { using type = double; };
template <> struct AccumulateType<int8_t, true> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, true> { using type = int64_t; };
template <> struct AccumulateType<char, true> { using type = int64_t; };
template <> struct AccumulateType<int16_t, true> { using type = int64_t; };
template <> struct AccumulateType<int32_t, true> { using type = int64_t; };
template <> struct AccumulateType<int64_t, true> { using type = int64_t; };
template <> struct AccumulateType<bool, true> {using type = bool; };
template <> struct AccumulateType<float, false> { using type = double; };
template <> struct AccumulateType<double, false> { using type = double; };
template <> struct AccumulateType<int8_t, false> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, false> { using type = int64_t; };
template <> struct AccumulateType<char, false> { using type = int64_t; };
template <> struct AccumulateType<int16_t, false> { using type = int64_t; };
template <> struct AccumulateType<int32_t, false> { using type = int64_t; };
template <> struct AccumulateType<int64_t, false> { using type = int64_t; };
template <> struct AccumulateType<bool, false> {using type = bool; };
template<typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
#define C10_WARP_SIZE 64
#define VEC 4
typedef int64_t IndexType;
constexpr int BlockReduceNumThreads=512;
constexpr int NumThreads = 256;
constexpr int ColwiseReduceTileSize = 32;
template <typename scalar_t, typename index_t, typename combine_t>
struct WelfordData {
scalar_t mean;
scalar_t m2;
index_t n;
combine_t nf;
C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
C10_HOST_DEVICE WelfordData(
scalar_t mean,
scalar_t m2,
index_t n,
combine_t nf)
: mean(mean), m2(m2), n(n), nf(nf) {}
};
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename combine_t, typename res_t>
struct WelfordOps {
public:
using acc_t = WelfordData<acc_scalar_t, index_t, combine_t>;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
acc_scalar_t delta = data - acc.mean;
// using acc.nf (combine_t) here, as acc.n (index_t) would still be converted;
// accumulation in reduce is done through index_t
acc_scalar_t new_mean = acc.mean + delta / (acc.nf + 1);
acc_scalar_t new_delta = data - new_mean;
return {
new_mean,
acc.m2 + delta * new_delta,
acc.n + 1,
combine_t(acc.n + 1), // accumulate for combine_t uses index_t
};
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
if (a.nf == 0) {
return b;
}
if (b.nf == 0) {
return a;
}
acc_scalar_t delta = b.mean - a.mean;
combine_t new_count = a.nf + b.nf;
acc_scalar_t nb_over_n = b.nf / new_count;
return {
a.mean + delta * nb_over_n,
a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
// setting acc.n as -1 since acc.n might not be able to represent the count
// correctly within its range, setting it to -1 to avoid confusion
-1,
new_count
};
}
inline C10_DEVICE res_t project(acc_t acc) const {
return res_t(acc.m2 / acc.nf, static_cast<scalar_t>(acc.mean));
}
inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
return {
__shfl_down(acc.mean, offset)
, __shfl_down(acc.m2, offset)
, __shfl_down(acc.n, offset)
, __shfl_down(acc.nf, offset)
};
}
};
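For reference, reduce() and combine() above implement the standard Welford / Chan et al. running-moment updates. Adding one sample x to a partial result with mean \mu, second moment M_2 and count n gives

\mu' = \mu + \frac{x - \mu}{n + 1}, \qquad M_2' = M_2 + (x - \mu)(x - \mu'),

and merging two partial results A and B with \delta = \mu_B - \mu_A gives

\mu_{AB} = \mu_A + \delta \frac{n_B}{n_A + n_B}, \qquad M_{2,AB} = M_{2,A} + M_{2,B} + \delta^2 \frac{n_A n_B}{n_A + n_B},

with project() returning the population variance M_2 / n together with the mean.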
template <int max=32,typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val,const ReduceOp& op) {
#pragma unroll
for (int offset = max; offset > 0; offset >>= 1) {
val = op.combine(val, op.warp_shfl_down(val, offset));
}
return val;
}
template <typename T, class ReduceOp>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, T* shared) {
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduce(val, op);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0) {
val= shared[lid];
val = WarpReduce<4>(val,op);
}
return val;
}
template <int max=32,typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = max; offset > 0; offset >>= 1) {
val += __shfl_down(val, offset);
}
return val;
}
template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduceSum(val);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0) {
val= shared[lid];
val = WarpReduceSum<4>(val);
}
return val;
}
template <typename scalar_t>
__global__ void layernorm_forward_kernel(const scalar_t* input,scalar_t* ret,acc_type<scalar_t, true>* mean,acc_type<scalar_t, true>* rstd,
const scalar_t* gamma,const scalar_t* beta,IndexType cols,double eps)
{
// dropout does nothing in eval mode
IndexType i=blockIdx.x;
// add + layernorm get mean and rstd
using T_ACC = acc_type<scalar_t, true>;
using WelfordType = WelfordData<T_ACC, IndexType, T_ACC>;
using WelfordOp = WelfordOps<T_ACC, T_ACC, IndexType, T_ACC, thrust::pair<T_ACC, T_ACC>>;
__shared__ typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::type val_shared[BlockReduceNumThreads/C10_WARP_SIZE];
WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
WelfordOp welford_op;
WelfordType val;
#pragma unroll
for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
IndexType index = i * cols + j;
val = welford_op.reduce(val, static_cast<T_ACC>(input[index]));
}
val = BlockReduce(val,welford_op,val_shared_ptr);
__shared__ T_ACC s_mean;
__shared__ T_ACC s_rstd;
if (threadIdx.x == 0) {
thrust::tie(s_rstd, s_mean) = welford_op.project(val);
mean[i] = s_mean;
s_rstd=rsqrt(s_rstd + static_cast<T_ACC>(eps));
rstd[i] = s_rstd;
}
__syncthreads();
//layernorm (x-mean)*rstd*gamma+beta
#pragma unroll
for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
IndexType index = i * cols + j;
ret[index] = static_cast<scalar_t>((static_cast<T_ACC>(input[index]) - s_mean)*s_rstd * (gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]))
+ (beta == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta[j])));
}
}
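The kernel above is a straightforward per-row layer norm: each block handles one row, BlockReduce computes the Welford mean and variance of that row, and thread 0 stores

\mu_i = \mathrm{mean}(x_i), \qquad r_i = \frac{1}{\sqrt{\mathrm{var}(x_i) + \epsilon}},

after which every element of the row is written as

y_{ij} = (x_{ij} - \mu_i)\, r_i\, \gamma_j + \beta_j,

with \gamma = 1 and \beta = 0 when gamma/beta are not supplied.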
template <typename T>
void LayerNormKernelImplInternal(
oneflow::ep::Stream* stream,
const T* X,
const T* gamma,
const T* beta,
int64_t M,
int64_t N,
double eps,
T* Y,
acc_type<T, true>* mean,
acc_type<T, true>* rstd) {
using T_ACC = acc_type<T, true>;
const T* X_data = X;
const T* gamma_data = gamma;
const T* beta_data = beta;
T* Y_data = Y;
T_ACC* mean_data = mean;
T_ACC* rstd_data = rstd;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
layernorm_forward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
X_data,Y_data,mean_data,rstd_data,gamma_data,beta_data,N,eps);
}
template <typename scalar_t>
__global__ void GammaBetaBackwardSimple(IndexType M,IndexType N,const scalar_t* dY,const scalar_t* X,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd,scalar_t* dg,scalar_t* db)
{
using T_ACC = acc_type<scalar_t, true>;
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
if (j < N) {
T_ACC sum1 = 0;
T_ACC sum2 = 0;
for (int64_t i = 0; i < M; ++i) {
const int64_t index = i * N + j;
sum1 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index]) *
(static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
static_cast<T_ACC>(rstd[i]);
sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
}
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
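Both GammaBetaBackwardSimple above and the tiled GammaBetaBackward below compute the same column-wise reductions over the rows i for each parameter index j:

d\gamma_j = \sum_i dy_{ij}\,(x_{ij} - \mu_i)\, r_i, \qquad d\beta_j = \sum_i dy_{ij},

where \mu_i and r_i are the saved per-row mean and inverse standard deviation.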
template <typename scalar_t>
__global__ void GammaBetaBackward(IndexType M,IndexType N,const scalar_t* dY,const scalar_t* X,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd,scalar_t* dg,scalar_t* db)
{
using T_ACC = acc_type<scalar_t, true>;
__shared__ T_ACC g_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
__shared__ T_ACC b_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
T_ACC dg_sum1 = 0;
T_ACC dg_sum2 = 0;
T_ACC db_sum1 = 0;
T_ACC db_sum2 = 0;
if (j < N) {
for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) {
const int64_t i1 = i;
const int64_t i2 = i + blockDim.y;
const int64_t index1 = i1 * N + j;
const int64_t index2 = i2 * N + j;
dg_sum1 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index1]) *
(static_cast<T_ACC>(X[index1]) - static_cast<T_ACC>(mean[i1])) *
static_cast<T_ACC>(rstd[i1]);
db_sum1 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index1]);
if (i2 < M) {
dg_sum2 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index2]) *
(static_cast<T_ACC>(X[index2]) - static_cast<T_ACC>(mean[i2])) *
static_cast<T_ACC>(rstd[i2]);
db_sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index2]);
}
}
}
g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
b_shared[threadIdx.y][threadIdx.x] = db_sum1;
b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
__syncthreads();
T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
sum1 = WarpReduceSum<16>(sum1);
sum2 = WarpReduceSum<16>(sum2);
if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
sum1 = WarpReduceSum<16>(sum1);
sum2 = WarpReduceSum<16>(sum2);
if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
if (j < N) {
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
}
template <typename scalar_t>
__global__ void LayerNormBackward_kernel(IndexType N,const scalar_t* dY,const scalar_t* X,const scalar_t* gamma,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd, scalar_t* dX, const scalar_t* add_to_output)
{
using T_ACC = acc_type<scalar_t, true>;
__shared__ T_ACC ds_shared[C10_WARP_SIZE];
__shared__ T_ACC db_shared[C10_WARP_SIZE];
const IndexType i = blockIdx.x;
T_ACC sum1 = 0;
T_ACC sum2 = 0;
#pragma unroll
for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
const IndexType index = i * N + j;
const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
sum1 += static_cast<T_ACC>(dY[index]) * static_cast<T_ACC>(X[index]) * gamma_v;
sum2 += static_cast<T_ACC>(dY[index]) * gamma_v;
}
sum1 = BlockReduceSum<T_ACC>(sum1, ds_shared);
sum2 = BlockReduceSum<T_ACC>(sum2, db_shared);
const T_ACC s = T_ACC(1) / static_cast<T_ACC>(N);
__shared__ T_ACC b;
__shared__ T_ACC c;
if (threadIdx.x == 0) {
b = (sum2 * static_cast<T_ACC>(mean[i]) - sum1) * static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(rstd[i]) *static_cast<T_ACC>(rstd[i]) * s;
c = -(b * static_cast<T_ACC>(mean[i]) + sum2 * static_cast<T_ACC>(rstd[i]) * s);
}
__syncthreads();
#pragma unroll
for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
const IndexType index = i * N + j;
const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
dX[index] = static_cast<scalar_t>(static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(dY[index]) * gamma_v + b * static_cast<T_ACC>(X[index]) + c
+ (add_to_output == nullptr ? T_ACC(0) : static_cast<T_ACC>(add_to_output[index])));
}
}
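In terms of the saved statistics, LayerNormBackward_kernel computes, per row i with N columns, the two block-reduced sums s_1 = \sum_j dy_{ij}\, x_{ij}\, \gamma_j and s_2 = \sum_j dy_{ij}\, \gamma_j, then

b = \frac{(s_2\, \mu_i - s_1)\, r_i^3}{N}, \qquad c = -\left(b\, \mu_i + \frac{s_2\, r_i}{N}\right), \qquad dx_{ij} = r_i\, \gamma_j\, dy_{ij} + b\, x_{ij} + c,

optionally adding add_to_output element-wise at the end.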
template <typename T>
void LayerNormBackwardKernelImplInternal(
oneflow::ep::Stream* stream,
const T* dY,
const T* X,
const acc_type<T, true>* mean,
const acc_type<T, true>* rstd,
const T* gamma,
int64_t M,
int64_t N,
T* dX,
const T* add_to_output) {
using T_ACC = acc_type<T, true>;
const T* dY_data = dY;
const T* X_data = X;
const T_ACC* mean_data = mean;
const T_ACC* rstd_data = rstd;
const T* gamma_data = gamma;
T* dX_data = dX;
const T* add_to_output_data = add_to_output;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
if (dX_data != nullptr) {
LayerNormBackward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
N, dY_data, X_data,gamma_data,mean_data,rstd_data,dX_data,add_to_output_data);
}
}
template <typename T>
void LayerNormBackwardKernelImplInternalParam(
oneflow::ep::Stream* stream,
const T* dY,
const T* X,
const acc_type<T, true>* mean,
const acc_type<T, true>* rstd,
int64_t M,
int64_t N,
T* dgamma,
T* dbeta) {
using T_ACC = acc_type<T, true>;
const T* dY_data = dY;
const T* X_data = X;
const T_ACC* mean_data = mean;
const T_ACC* rstd_data = rstd;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
T* dgamma_data = dgamma;
T* dbeta_data = dbeta;
if (M < 512) {
// For small batch size, do colwise reduce directly.
const int64_t B = (N + NumThreads - 1) / NumThreads;
GammaBetaBackwardSimple<T>
<<<B, NumThreads, 0, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
} else {
const int64_t B =
(N + ColwiseReduceTileSize - 1) / ColwiseReduceTileSize;
constexpr int kThreadX = ColwiseReduceTileSize;
constexpr int kThreadY = ColwiseReduceTileSize / 2;
GammaBetaBackward<T>
<<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
}
}
namespace oneflow {
template<typename T>
class LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
LayerNormGpuKernel() = default;
~LayerNormGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
double epsilon = ctx->Attr<double>("epsilon");
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
const T* gamma_ptr = nullptr;
const T* beta_ptr = nullptr;
if (ctx->has_input("gamma", 0)) {
const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
gamma_ptr = gamma->dptr<T>();
CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);
}
if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr<T>(); }
// DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
// gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormKernelImplInternal<T>(ctx->stream(), x->dptr<T>(), gamma_ptr, beta_ptr, num_instances, norm_size, epsilon,
y->mut_dptr<T>(), mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());
};
};
#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm") \
.SetCreateFn<LayerNormGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));
REGISTER_LAYER_NORM_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
LayerNormGradGpuKernel() = default;
~LayerNormGradGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
const T* gamma_ptr = nullptr;
if (ctx->has_input("gamma", 0)) {
gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr<T>();
}
const T* add_to_output_ptr = nullptr;
if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), dx->data_type());
CHECK_EQ(add_to_output->shape_view(), dx->shape_view());
add_to_output_ptr = add_to_output->dptr<T>();
}
// LaunchLayerNormBackward<T>(ctx->stream(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(),
// mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr<T>());
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormBackwardKernelImplInternal<T>(ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(),
gamma_ptr, num_instances, norm_size, dx->mut_dptr<T>(), add_to_output_ptr);
};
};
#define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_grad") \
.SetCreateFn<LayerNormGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn( \
[](const user_op::InferContext& ctx, \
const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { \
if (ctx.has_input("_add_to_output", 0)) { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \
} \
return Maybe<void>::Ok(); \
});
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
LayerNormParamGradGpuKernel() = default;
~LayerNormParamGradGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
// const DataType data_type = dy->data_type();
// const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
// const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
// const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
// T* tmp_gamma_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());
// T* tmp_beta_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_gamma_diff_size);
// T* reduce_buf_ptr =
// reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
// LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
// 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
// num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
// inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
// const int32_t m = norm_size;
// const int32_t n = 1;
// const int32_t k = grid_dim_y;
// std::unique_ptr<ep::primitive::Fill> fill =
// ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),
// data_type);
// CHECK(fill);
// fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y);
// std::unique_ptr<ep::primitive::Matmul> matmul =
// ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
// ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,
// ep::primitive::BlasTransposeType::N);
// CHECK(matmul);
// if (ctx->has_output("gamma_diff", 0)) {
// user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
// matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0,
// gamma_diff->mut_dptr());
// }
// if (ctx->has_output("beta_diff", 0)) {
// user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
// matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0,
// beta_diff->mut_dptr());
// }
T* gamma_diff_ptr = nullptr;
T* beta_diff_ptr = nullptr;
if (ctx->has_output("gamma_diff", 0)) {
gamma_diff_ptr = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0)->mut_dptr<T>();
}
if (ctx->has_output("beta_diff", 0)) {
beta_diff_ptr = ctx->Tensor4ArgNameAndIndex("beta_diff", 0)->mut_dptr<T>();
}
LayerNormBackwardKernelImplInternalParam<T>(ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(),
num_instances, norm_size, gamma_diff_ptr, beta_diff_ptr);
};
};
#define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_param_grad") \
.SetCreateFn<LayerNormParamGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis"); \
const bool has_gamma_diff = ctx->has_output("gamma_diff", 0); \
const bool has_beta_diff = ctx->has_output("beta_diff", 0); \
const auto& dy = ctx->InputTensorDesc("dy", 0); \
const int64_t num_instances = dy.shape().Count(0, begin_params_axis); \
const int64_t norm_size = dy.shape().Count(begin_params_axis); \
const int64_t grid_dim_y = num_instances; \
size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \
return tmp_buffer_size; \
});
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif #endif
} } // namespace oneflow
#endif
...@@ -28,7 +28,7 @@ limitations under the License. ...@@ -28,7 +28,7 @@ limitations under the License.
#elif defined(__HIPCC__) #elif defined(__HIPCC__)
#include <hip/hsa_detail/math_functions.h> #include <hip/math_functions.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__) #if defined(__HIP_DEVICE_COMPILE__)
......
...@@ -882,6 +882,22 @@ namespace oneflow { ...@@ -882,6 +882,22 @@ namespace oneflow {
namespace { namespace {
// Debug helper: copies `size` elements of a device tensor to the host and prints them,
// 16 values per line. Intended for ad-hoc debugging only.
template<typename T>
void printTensor(const std::string& str, const T* devTensor, size_t size) {
T* hostTensor = (T*)malloc(size * sizeof(T));
// Make sure pending device work has finished before reading the tensor.
hipDeviceSynchronize();
hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
std::cout << str << ": ";
for (size_t i = 0; i < size; i++) {
if (i % 16 == 0) {
std::cout << std::endl;
}
std::cout << hostTensor[i] << ", ";
}
std::cout << str << ": finish" << std::endl;
free(hostTensor);
}
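// Hypothetical usage from a kernel's Compute(), e.g. to dump the first values of y:
//   printTensor<float>("y", y->dptr<float>(), 64);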
hipdnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) { hipdnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {
if (dim == 2) { if (dim == 2) {
return HIPDNN_BATCHNORM_PER_ACTIVATION; return HIPDNN_BATCHNORM_PER_ACTIVATION;
...@@ -969,6 +985,15 @@ class CudnnTensorDescHelper final { ...@@ -969,6 +985,15 @@ class CudnnTensorDescHelper final {
int32_t param_size_ = 0; int32_t param_size_ = 0;
}; };
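// When "_add_to_output" is present, the inference kernel stages the addend in tmp_buffer and
// accumulates it into y after the batch-norm forward (see hipdnnAddTensor below), so the
// temporary buffer must be able to hold a full copy of y; otherwise 1 byte is enough.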
size_t InferInferTmpSize(user_op::InferContext* ctx) {
const auto& y = ctx->OutputTensorDesc("y", 0);
if (ctx->has_input("_add_to_output", 0)) {
return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
} else {
return 1;
}
}
size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type, size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
const int32_t axis) { const int32_t axis) {
return 1; return 1;
...@@ -976,8 +1001,13 @@ size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_typ ...@@ -976,8 +1001,13 @@ size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_typ
size_t InferTrainTmpSize(user_op::InferContext* ctx) { size_t InferTrainTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0); const auto& x = ctx->InputTensorDesc("x", 0);
const auto& y = ctx->OutputTensorDesc("y", 0);
const auto axis = ctx->Attr<int32_t>("axis"); const auto axis = ctx->Attr<int32_t>("axis");
return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis); if (ctx->has_input("_add_to_output", 0)) {
return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
} else {
return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);
}
} }
size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type, size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
...@@ -1016,6 +1046,9 @@ class NormalizationInferenceKernel final : public user_op::OpKernel, ...@@ -1016,6 +1046,9 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
const auto axis = ctx->Attr<int32_t>("axis"); const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon"); const auto epsilon = ctx->Attr<float>("epsilon");
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
const DataType data_type = x->data_type(); const DataType data_type = x->data_type();
CHECK_EQ(x->shape_view(), y->shape_view()); CHECK_EQ(x->shape_view(), y->shape_view());
CHECK_EQ(y->data_type(), data_type); CHECK_EQ(y->data_type(), data_type);
...@@ -1030,17 +1063,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel, ...@@ -1030,17 +1063,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
desc_helper.CheckParamTensor(moving_variance); desc_helper.CheckParamTensor(moving_variance);
const void* sp_alpha = CudnnSPOnePtr(data_type); const void* sp_alpha = CudnnSPOnePtr(data_type);
const void* sp_beta; const void* sp_beta = CudnnSPZeroPtr(data_type);
if (ctx->has_input("_add_to_output", 0)) { if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), y->data_type()); CHECK_EQ(add_to_output->data_type(), y->data_type());
CHECK_EQ(add_to_output->shape_view(), y->shape_view()); CHECK_EQ(add_to_output->shape_view(), y->shape_view());
Memcpy<DeviceType::kCUDA>( Memcpy<DeviceType::kCUDA>(
ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(), ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
sp_beta = CudnnSPOnePtr(data_type);
} else {
sp_beta = CudnnSPZeroPtr(data_type);
} }
OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardInference( OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardInference(
...@@ -1048,6 +1079,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel, ...@@ -1048,6 +1079,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),
moving_variance->dptr(), epsilon)); moving_variance->dptr(), epsilon));
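// Accumulate the staged addend after the forward pass: y = 1 * add_to_output + 1 * y.
// Staging through tmp_buffer avoids passing a non-zero blend factor (sp_beta) to
// hipdnnBatchNormalizationForwardInference, which is presumably why the in-place copy
// into y was dropped above.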
if (ctx->has_input("_add_to_output", 0)) {
sp_beta = CudnnSPOnePtr(data_type);
OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
sp_alpha, desc_helper.xy_desc(),
add_to_output_dev, sp_beta, desc_helper.xy_desc(),
y->mut_dptr()));
}
} }
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
...@@ -1057,6 +1097,7 @@ REGISTER_USER_KERNEL("normalization") ...@@ -1057,6 +1097,7 @@ REGISTER_USER_KERNEL("normalization")
.SetCreateFn<NormalizationInferenceKernel>() .SetCreateFn<NormalizationInferenceKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
&& (user_op::HobAttr<bool>("training") == false)) && (user_op::HobAttr<bool>("training") == false))
.SetInferTmpSizeFn(InferInferTmpSize)
.SetInplaceProposalFn([](const user_op::InferContext& ctx, .SetInplaceProposalFn([](const user_op::InferContext& ctx,
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
if (ctx.has_input("_add_to_output", 0)) { if (ctx.has_input("_add_to_output", 0)) {
...@@ -1068,76 +1109,78 @@ REGISTER_USER_KERNEL("normalization") ...@@ -1068,76 +1109,78 @@ REGISTER_USER_KERNEL("normalization")
constexpr int64_t kCudaWarpSize = 64; constexpr int64_t kCudaWarpSize = 64;
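// On AMD GPUs a wavefront is 64 lanes wide, so __ballot() produces a 64-bit mask; the relu
// mask is therefore stored as one int64_t per wavefront (kCudaWarpSize == 64 here).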
template<typename T> template<typename T>
__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) { __global__ void ReluGpu(int64_t n, const T* x, T* y, int64_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize; const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0.f); const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
const T x_val = x[i]; const T x_val = x[i];
const bool is_positive = (x_val > zero); const bool is_positive = (x_val > zero);
int32_t warp_mask = __ballot(static_cast<int>(is_positive)); unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; } int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
y[i] = is_positive ? x_val : zero; y[i] = is_positive ? x_val : zero;
} }
} }
template<typename T> template<typename T>
__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) { __global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize; const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0.f); const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
const T sum = x[i] + addend[i]; const T sum = x[i] + addend[i];
const bool is_positive = (sum > zero); const bool is_positive = (sum > zero);
int32_t warp_mask = __ballot(static_cast<int>(is_positive)); unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; } int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
y[i] = is_positive ? sum : zero; y[i] = is_positive ? sum : zero;
} }
} }
template<typename T> template<typename T>
void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) { void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int64_t* mask) {
ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask); stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask);
} }
template<typename T> template<typename T>
void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) { void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask); stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask);
} }
template<typename T> template<typename T>
__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) { __global__ void ReluBackwardGpu(int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
int32_t lane_id = threadIdx.x % kCudaWarpSize; int64_t lane_id = threadIdx.x % kCudaWarpSize;
CUDA_1D_KERNEL_LOOP(i, n) { CUDA_1D_KERNEL_LOOP(i, n) {
int32_t mask_val = mask[i / kCudaWarpSize]; int64_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & (1 << lane_id); bool is_positive = mask_val & ((int64_t)1 << lane_id);
addend_diff[i] = static_cast<T>(is_positive) * dy[i]; addend_diff[i] = static_cast<T>(is_positive) * dy[i];
} }
} }
#if CUDA_VERSION >= 11000 // #if CUDA_VERSION >= 11000
template<> // template<>
__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy, // __global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
nv_bfloat16* addend_diff) { // nv_bfloat16* addend_diff) {
int32_t lane_id = threadIdx.x % kCudaWarpSize; // int32_t lane_id = threadIdx.x % kCudaWarpSize;
CUDA_1D_KERNEL_LOOP(i, n) { // CUDA_1D_KERNEL_LOOP(i, n) {
int32_t mask_val = mask[i / kCudaWarpSize]; // int32_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & (1 << lane_id); // bool is_positive = mask_val & (1 << lane_id);
addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i]; // addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
} // }
} // }
#endif // #endif
template<typename T> template<typename T>
void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) { void ReluBackward(ep::Stream* stream, int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff); stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff);
} }
void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y, void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y,
int32_t* mask) { int64_t* mask) {
if (data_type == kFloat) { if (data_type == kFloat) {
Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask); Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask);
} else if (data_type == kDouble) { } else if (data_type == kDouble) {
...@@ -1156,7 +1199,7 @@ void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x ...@@ -1156,7 +1199,7 @@ void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x
} }
} }
void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x,
const void* addend, void* y, int32_t* mask) { const void* addend, void* y, int64_t* mask) {
if (data_type == kFloat) { if (data_type == kFloat) {
AddRelu<float>(stream, n, reinterpret_cast<const float*>(x), AddRelu<float>(stream, n, reinterpret_cast<const float*>(x),
reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask); reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask);
...@@ -1178,7 +1221,7 @@ void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void ...@@ -1178,7 +1221,7 @@ void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
} }
void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask, void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int64_t* mask,
const void* dy, void* addend_diff) { const void* dy, void* addend_diff) {
if (data_type == kFloat) { if (data_type == kFloat) {
ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy), ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy),
...@@ -1225,6 +1268,9 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op: ...@@ -1225,6 +1268,9 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
hipdnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes()); hipdnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode); const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0); const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0); auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
...@@ -1244,17 +1290,15 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op: ...@@ -1244,17 +1290,15 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
desc_helper.CheckParamTensor(moving_variance); desc_helper.CheckParamTensor(moving_variance);
} }
const void* sp_alpha = CudnnSPOnePtr(data_type); const void* sp_alpha = CudnnSPOnePtr(data_type);
const void* sp_beta; const void* sp_beta = CudnnSPZeroPtr(data_type);
if (ctx->has_input("_add_to_output", 0)) { if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), y->data_type()); CHECK_EQ(add_to_output->data_type(), y->data_type());
CHECK_EQ(add_to_output->shape_view(), y->shape_view()); CHECK_EQ(add_to_output->shape_view(), y->shape_view());
Memcpy<DeviceType::kCUDA>( Memcpy<DeviceType::kCUDA>(
ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(), ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type())); add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
sp_beta = CudnnSPOnePtr(data_type);
} else {
sp_beta = CudnnSPZeroPtr(data_type);
} }
OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardTraining( OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardTraining(
...@@ -1265,6 +1309,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op: ...@@ -1265,6 +1309,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(), moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),
inv_variance->mut_dptr())); inv_variance->mut_dptr()));
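// Same staging as in the inference kernel: the "_add_to_output" addend kept in tmp_buffer is
// accumulated into y after the batch-norm forward via hipdnnAddTensor.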
if (ctx->has_input("_add_to_output", 0)) {
sp_beta = CudnnSPOnePtr(data_type);
OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
sp_alpha, desc_helper.xy_desc(),
add_to_output_dev, sp_beta, desc_helper.xy_desc(),
y->mut_dptr()));
}
if (ctx->op_type_name() == "normalization_add_relu") { if (ctx->op_type_name() == "normalization_add_relu") {
CHECK(!ctx->has_input("_add_to_output", 0)); CHECK(!ctx->has_input("_add_to_output", 0));
const int64_t elem_cnt = x->shape_view().elem_cnt(); const int64_t elem_cnt = x->shape_view().elem_cnt();
...@@ -1272,10 +1324,10 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op: ...@@ -1272,10 +1324,10 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
if (ctx->has_input("addend", 0)) { if (ctx->has_input("addend", 0)) {
const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0); const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0);
AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(), AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),
mask->mut_dptr<int32_t>()); mask->mut_dptr<int64_t>());
} else { } else {
Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(), Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),
mask->mut_dptr<int32_t>()); mask->mut_dptr<int64_t>());
} }
} }
} }
...@@ -1351,7 +1403,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel, ...@@ -1351,7 +1403,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
if (ctx->has_output("addend_diff", 0)) { if (ctx->has_output("addend_diff", 0)) {
user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0); user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0);
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(), ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
addend_diff->mut_dptr()); addend_diff->mut_dptr());
bn_workspace_ptr = tmp_buffer->mut_dptr(); bn_workspace_ptr = tmp_buffer->mut_dptr();
bn_workspace_size = tmp_buffer->shape_view().elem_cnt(); bn_workspace_size = tmp_buffer->shape_view().elem_cnt();
...@@ -1361,7 +1413,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel, ...@@ -1361,7 +1413,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
const size_t relu_dx_size = const size_t relu_dx_size =
GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type)); GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type));
CHECK_GE(tmp_buffer_size, relu_dx_size); CHECK_GE(tmp_buffer_size, relu_dx_size);
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(), ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
tmp_buffer->mut_dptr()); tmp_buffer->mut_dptr());
bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size; bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;
bn_workspace_size = tmp_buffer_size - relu_dx_size; bn_workspace_size = tmp_buffer_size - relu_dx_size;
...@@ -1393,231 +1445,6 @@ REGISTER_USER_KERNEL("normalization_add_relu_grad") ...@@ -1393,231 +1445,6 @@ REGISTER_USER_KERNEL("normalization_add_relu_grad")
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)) .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferGradTmpSize); .SetInferTmpSizeFn(InferGradTmpSize);
#if (HIPDNN_VERSION >= 7401)
size_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
cudnnBatchNormOps_t ops;
hipdnnTensorDescriptor_t z_desc;
if (ctx->has_input("addend", 0)) {
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
z_desc = desc_helper.xy_desc();
} else {
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
z_desc = nullptr;
}
OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc,
desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
Singleton<CudnnHandlePool>::Get()->Put(handle);
return std::max(size_in_bytes, static_cast<size_t>(1));
}
size_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
cudnnBatchNormOps_t ops;
hipdnnTensorDescriptor_t z_desc;
if (ctx->has_output("addend_diff", 0)) {
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
z_desc = desc_helper.xy_desc();
} else {
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
z_desc = nullptr;
}
OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), desc_helper.xy_desc(),
desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
activation_desc.Get(), &size_in_bytes));
Singleton<CudnnHandlePool>::Get()->Put(handle);
return std::max(size_in_bytes, static_cast<size_t>(1));
}
class FusedNormalizationAddReluKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
FusedNormalizationAddReluKernel() = default;
~FusedNormalizationAddReluKernel() override = default;
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0);
auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon");
const auto momentum = ctx->Attr<float>("momentum");
auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const DataType data_type = x->data_type();
CHECK_EQ(x->shape_view(), y->shape_view());
CHECK_EQ(y->data_type(), data_type);
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(moving_mean);
desc_helper.CheckParamTensor(moving_variance);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
hipdnnTensorDescriptor_t z_desc;
const void* z_ptr;
cudnnBatchNormOps_t ops;
if (ctx->has_input("addend", 0)) {
z_desc = desc_helper.xy_desc();
z_ptr = ctx->Tensor4ArgNameAndIndex("addend", 0)->dptr();
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
z_desc = nullptr;
z_ptr = nullptr;
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
}
size_t min_workspace_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
activation_desc.Get(), &min_workspace_size));
const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
CHECK_GE(workspace_size, min_workspace_size);
size_t min_reserve_space_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
CHECK_GE(reserve_space_size, min_reserve_space_size);
OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(),
z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(),
gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(),
activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(),
reserve_space_size));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu")
.SetCreateFn<FusedNormalizationAddReluKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize);
class FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
FusedNormalizationAddReluGradUserKernel() = default;
~FusedNormalizationAddReluGradUserKernel() override = default;
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
const auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon");
const DataType data_type = x->data_type();
CHECK_EQ(dy->shape_view(), x->shape_view());
CHECK_EQ(dy->data_type(), data_type);
CHECK_EQ(dx->shape_view(), x->shape_view());
CHECK_EQ(dx->data_type(), data_type);
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(gamma_diff);
desc_helper.CheckParamTensor(beta_diff);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
hipdnnTensorDescriptor_t dz_desc;
void* dz_ptr;
cudnnBatchNormOps_t ops;
if (ctx->has_output("addend_diff", 0)) {
dz_desc = desc_helper.xy_desc();
dz_ptr = ctx->Tensor4ArgNameAndIndex("addend_diff", 0)->mut_dptr();
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
dz_desc = nullptr;
dz_ptr = nullptr;
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
}
size_t min_workspace_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc,
desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(),
&min_workspace_size));
const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
CHECK_GE(workspace_size, min_workspace_size);
size_t min_reserve_space_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
CHECK_GE(reserve_space_size, min_reserve_space_size);
OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),
CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),
y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(),
dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(),
gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(),
activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size,
const_cast<void*>(reserve_space->dptr()), reserve_space_size));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu_grad")
.SetCreateFn<FusedNormalizationAddReluGradUserKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize);
#endif
} // namespace } // namespace
} // namespace oneflow } // namespace oneflow
......
...@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() { ...@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0); CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
reserve_space_bits = reserve_space_bits / split_num; reserve_space_bits = reserve_space_bits / split_num;
} }
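// On ROCm the add-relu mask is packed as one int64_t per 64-lane wavefront, so reserve_space
// is counted in 64-bit words (32-bit words on CUDA).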
#ifdef WITH_ROCM
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
#else
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)})); reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
#endif
return Maybe<void>::Ok(); return Maybe<void>::Ok();
})(ctx); })(ctx);
} }
...@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() { ...@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> { user_op::TensorDesc* reserve_space) -> Maybe<void> {
const auto& x_desc = ctx->InputTensorDesc("x", 0); const auto& x_desc = ctx->InputTensorDesc("x", 0);
#ifdef WITH_ROCM
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
#else
reserve_space->set_shape( reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)})); Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
#endif
return Maybe<void>::Ok(); return Maybe<void>::Ok();
})(ctx); })(ctx);
} }
...@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() { ...@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) { /* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> { user_op::TensorDesc* reserve_space) -> Maybe<void> {
#ifdef WITH_ROCM
reserve_space->set_data_type(DataType::kInt64);
#else
reserve_space->set_data_type(DataType::kInt32); reserve_space->set_data_type(DataType::kInt32);
#endif
return Maybe<void>::Ok(); return Maybe<void>::Ok();
})(ctx); })(ctx);
} }
......