Commit 6046d8fb authored by yuguo960516yuguo

dtk23.04

parent a715222c
......@@ -334,16 +334,16 @@ elseif(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
endif()
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS "-O0")
endif()
# if (BUILD_ROCM)
# # AMD compiler fails to compile these three files with '-O1/2/3'.
# # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# # so '-O0' will override '-O1/2/3'.
# set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
# # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
# PROPERTIES COMPILE_OPTIONS "-O0")
# endif()
if(BUILD_CUDA)
string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
......
......@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template<typename T>
OF_DEVICE_FUNC T DeviceMin(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a < b ? a : b;
#else
return std::min(a, b);
......@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template<typename T>
OF_DEVICE_FUNC T DeviceMax(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a > b ? a : b;
#else
return std::max(a, b);
......
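With the added `defined(__HIP_DEVICE_COMPILE__)` clause, the ternary branch is taken in both nvcc and hipcc device passes, while host code still falls back to `std::min`/`std::max`. A minimal host-only sketch of the same dispatch pattern, assuming nothing beyond the standard library (`DeviceMinSketch` and `main` are illustrative names, not from the repository):

```cpp
#include <algorithm>
#include <cassert>

template<typename T>
T DeviceMinSketch(T a, T b) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  // Device pass (nvcc or hipcc): avoid <algorithm> in device code, use a plain ternary.
  return a < b ? a : b;
#else
  // Host pass: neither macro is defined, so the standard library is available.
  return std::min(a, b);
#endif
}

int main() {
  assert(DeviceMinSketch(3, 7) == 3);
  assert(DeviceMinSketch(2.5, -1.0) == -1.0);
  return 0;
}
```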
......@@ -54,7 +54,7 @@ struct CastCASImpl {
}
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)
template<typename T>
struct CastCASImpl<T, unsigned short int> {
......
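The widened guard routes targets without a native 16-bit `atomicCAS` (pre-sm70 CUDA, clang CUDA, and ROCm) to the `unsigned short int` specialization, which has to emulate a 16-bit compare-and-swap on top of a wider atomic word. A host-side analogue of that emulation, assuming only `std::atomic` (`Cas16ViaCas32` is an illustrative helper, not the kernel code):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>

// Perform a 16-bit CAS on one half of a 32-bit word; returns the previous 16-bit value.
static uint16_t Cas16ViaCas32(std::atomic<uint32_t>* word, bool high_half,
                              uint16_t expected, uint16_t desired) {
  const int shift = high_half ? 16 : 0;
  uint32_t old_word = word->load();
  for (;;) {
    const uint16_t old16 = static_cast<uint16_t>(old_word >> shift);
    if (old16 != expected) return old16;  // the 16-bit CAS fails; report the current value
    const uint32_t new_word =
        (old_word & ~(0xFFFFu << shift)) | (static_cast<uint32_t>(desired) << shift);
    if (word->compare_exchange_weak(old_word, new_word)) return old16;
    // compare_exchange_weak refreshed old_word on failure; loop and retry.
  }
}

int main() {
  std::atomic<uint32_t> word{0x12345678u};
  const uint16_t prev = Cas16ViaCas32(&word, /*high_half=*/false, 0x5678, 0xBEEF);
  std::printf("prev=0x%04x word=0x%08x\n", prev, word.load());  // prev=0x5678 word=0x1234beef
  return 0;
}
```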
......@@ -308,7 +308,8 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
#ifdef WITH_ROCM
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
......@@ -318,10 +319,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
#ifdef WITH_ROCM
__syncthreads();
__syncthreads();
if (wid == 0) {
#else
__syncwarp();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
#endif
T block_mean = 0;
T block_m2 = 0;
......@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
......@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
......@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
int grid_dim_x;
{
GPU(Error_t) err =
......@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
......@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......
......@@ -252,15 +252,21 @@ struct SetContext {
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_hit);
const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
} else {
return -1;
}
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
#endif
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
return -1;
}
#endif
}
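On ROCm the wavefront is 64 lanes wide, so `__ballot` yields a 64-bit mask and the first hitting lane has to be located with `__ffsll` rather than `__ffs`. A host-side sketch of that lookup, assuming plain C++ (`Ffs64` stands in for `__ffsll` and is not a HIP API):

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for __ffsll(): 1-based position of the least significant set bit, 0 if none.
static int Ffs64(uint64_t mask) {
  if (mask == 0) return 0;
  int pos = 1;
  while ((mask & 1ull) == 0) { mask >>= 1; ++pos; }
  return pos;
}

int main() {
  uint64_t hit_mask = 0;
  hit_mask |= 1ull << 47;  // pretend lane 47 of the 64-wide wavefront reported a key hit
  const int hit_way = (hit_mask != 0) ? Ffs64(hit_mask) - 1 : -1;
  std::printf("hit way = %d\n", hit_way);  // 47 -- a 32-bit mask could not represent this lane
  return 0;
}
```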
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
......@@ -277,15 +283,16 @@ struct SetContext {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
#endif
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
#ifdef WITH_ROCM
insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way);
#else
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
#endif
if (lane_age > insert_way_age) {
......@@ -301,12 +308,16 @@ __syncwarp();
}
if (insert_way == -1) {
#ifdef WITH_ROCM
const unsigned valid_mask = __ballot(lane_age != 0);
const unsigned long long int valid_mask = __ballot(lane_age != 0);
#else
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#endif
if (valid_mask != kFullMask) {
#ifdef WITH_ROCM
insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
#else
insert_way = __popc(static_cast<int>(valid_mask));
#endif
if (lane_age > 0) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
......@@ -329,7 +340,8 @@ __syncwarp();
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
const int insert_way = __ffsll(valid_mask_tmp) - 1;
#else
const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
#endif
......@@ -594,7 +606,8 @@ __syncwarp();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
#ifdef WITH_ROCM
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
const int key_count = __popcll(valid_mask_tmp);
#else
const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
#endif
......
......@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
......
......@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
OF_DEVICE_FUNC half operator()(half src) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);
......@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* dst, const half* src) const {
const half2 src2 = *(reinterpret_cast<const half2*>(src));
const float2 tanh_in = __half22float2(__hmul2(
......
......@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like",
"layer_norm_grad",
"layer_norm",
"layer_norm_param_grad"
};
return black_list;
}
"layer_norm_param_grad",
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {"add_n",
"add_n",
"tf_avg_pool_1d",
"tf_avg_pool_1d_grad",
"tf_avg_pool_2d",
"tf_avg_pool_2d_grad",
"tf_avg_pool_3d",
"tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad",
// "avg_pool_1d",
// "avg_pool_1d_grad",
// "avg_pool_2d",
// "avg_pool_2d_grad",
// "avg_pool_3d",
// "avg_pool_3d_grad",
"bias_add",
"sigmoid_grad",
"tanh",
......@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad",
"silu",
"silu_grad",
"fused_weighted_sum"};
return gray_list;
}
"fused_weighted_sum",
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {"broadcast_like",
"broadcast_like",
"gather",
"gather_nd",
"scatter_nd",
......@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad",
"tf_max_pool_3d",
"tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad",
// "max_pool_1d",
// "max_pool_1d_grad",
// "max_pool_2d",
// "max_pool_2d_grad",
// "max_pool_3d",
// "max_pool_3d_grad",
"reshape",
"reshape_like",
"relu",
......@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity",
"to_contiguous",
"copy",
"upsample_nearest_2d"};
"upsample_nearest_2d"
};
return black_list;
}
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return gray_list;
}
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return clear_list;
}
......
......@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}
OF_DEVICE_FUNC half operator()(const half x, const half m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
const float tanh_out = TanhApprox(tanh_in);
......@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* y, const half* x, const half* m) const {
const half2 x2 = *(reinterpret_cast<const half2*>(x));
const float2 tanh_in = __half22float2(
......@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
const half& m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const half halpha = __float2half_rn(alpha);
const half hbeta = __float2half_rn(beta);
const half hone = __float2half_rn(1.0F);
......@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
const half* m) const {
const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
......
......@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr int kReduceBlockSize = 512;
constexpr int kBlockSize = 128;
#ifdef WITH_ROCM
constexpr int kNumWaves = 64;
#else
constexpr int kNumWaves = 32;
#endif
inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
int dev;
......
......@@ -27,9 +27,12 @@ limitations under the License.
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#ifdef WITH_CUDA
#include <thrust/pair.h>
#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#endif
namespace oneflow {
namespace {
......@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) {
return val;
}
constexpr int tile_size = 32;
constexpr int num_per_block = 4;
#ifdef WITH_ROCM
constexpr int tile_size = 64;
constexpr int block_dim_x = 64;
constexpr int block_dim_y = 64 / num_per_block;
#else
constexpr int tile_size = 32;
constexpr int block_dim_x = 32;
constexpr int block_dim_y = 32 / num_per_block;
#endif
template<typename T, typename ComputeType>
__global__ void LayerNormParamGrad(int rows, int cols, const T* __restrict__ dy,
const T* __restrict__ x, const ComputeType* __restrict__ mean,
const ComputeType* __restrict__ inv_var,
T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) {
#ifdef WITH_ROCM
__shared__ ComputeType dgamma[64][65];
__shared__ ComputeType dbeta[64][65];
#else
__shared__ ComputeType dgamma[32][33];
__shared__ ComputeType dbeta[32][33];
#endif
ComputeType dgamma_sum[num_per_block];
ComputeType dbeta_sum[num_per_block];
#pragma unroll
......@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) {
const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size;
const int block_size = block_dim_x * block_dim_y;
int max_active_blocks = 0;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
OF_CUDA_CHECK(GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks, LayerNormParamGrad<T, ComputeType>, block_size, 0));
int waves = 1;
int dev;
OF_CUDA_CHECK(cudaGetDevice(&dev));
OF_CUDA_CHECK(GPU(GetDevice)(&dev));
int sm_count;
OF_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));
OF_CUDA_CHECK(GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev));
int num_blocks = max_active_blocks * sm_count * waves;
int grid_dim_y = std::min(max_grid_dim_y, static_cast<int>(num_blocks / grid_dim_x));
return std::max(grid_dim_y, 1);
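`GetGirdDimY` sizes the second grid dimension from occupancy: one wave of resident blocks across all multiprocessors, capped by how many row tiles exist. A host-side sketch of the arithmetic with made-up occupancy and SM-count values (on device these come from `GPU(OccupancyMaxActiveBlocksPerMultiprocessor)` and `GPU(DeviceGetAttribute)`):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_instances = 8192;
  const int tile_size = 64;                 // ROCm tile; 32 on CUDA
  const int grid_dim_x = 16;
  const int max_grid_dim_y =
      static_cast<int>((num_instances + tile_size - 1) / tile_size);  // 128 row tiles
  const int max_active_blocks = 4;          // made-up occupancy query result
  const int sm_count = 120;                 // made-up multiprocessor count
  const int waves = 1;
  const int num_blocks = max_active_blocks * sm_count * waves;        // 480 resident blocks
  const int grid_dim_y = std::max(std::min(max_grid_dim_y, num_blocks / grid_dim_x), 1);
  std::printf("grid_dim_y = %d\n", grid_dim_y);  // min(128, 30) -> 30
  return 0;
}
```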
......@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const DataType data_type = dy->data_type();
const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
......@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
T* reduce_buf_ptr =
reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(block_dim_x, block_dim_y),
0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
......@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
});
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
#ifdef WITH_CUDA
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif
} // namespace oneflow
#endif
#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#include <thrust/pair.h>
template <typename T, bool is_cuda>
struct AccumulateType { };
#if defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; };
#endif
template <> struct AccumulateType<float, true> { using type = float; };
template <> struct AccumulateType<double, true> { using type = double; };
template <> struct AccumulateType<int8_t, true> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, true> { using type = int64_t; };
template <> struct AccumulateType<char, true> { using type = int64_t; };
template <> struct AccumulateType<int16_t, true> { using type = int64_t; };
template <> struct AccumulateType<int32_t, true> { using type = int64_t; };
template <> struct AccumulateType<int64_t, true> { using type = int64_t; };
template <> struct AccumulateType<bool, true> {using type = bool; };
template <> struct AccumulateType<float, false> { using type = double; };
template <> struct AccumulateType<double, false> { using type = double; };
template <> struct AccumulateType<int8_t, false> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, false> { using type = int64_t; };
template <> struct AccumulateType<char, false> { using type = int64_t; };
template <> struct AccumulateType<int16_t, false> { using type = int64_t; };
template <> struct AccumulateType<int32_t, false> { using type = int64_t; };
template <> struct AccumulateType<int64_t, false> { using type = int64_t; };
template <> struct AccumulateType<bool, false> {using type = bool; };
template<typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
#define C10_WARP_SIZE 64
#define VEC 4
typedef int64_t IndexType ;
constexpr int BlockReduceNumThreads=512;
constexpr int NumThreads = 256;
constexpr int ColwiseReduceTileSize = 32;
template <typename scalar_t, typename index_t, typename combine_t>
struct WelfordData {
scalar_t mean;
scalar_t m2;
index_t n;
combine_t nf;
C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
C10_HOST_DEVICE WelfordData(
scalar_t mean,
scalar_t m2,
index_t n,
combine_t nf)
: mean(mean), m2(m2), n(n), nf(nf) {}
};
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename combine_t, typename res_t>
struct WelfordOps {
public:
using acc_t = WelfordData<acc_scalar_t, index_t, combine_t>;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
acc_scalar_t delta = data - acc.mean;
      // use acc.nf (combine_t) here, as acc.n (index_t) would still be converted;
      // the accumulation inside reduce is done through index_t
acc_scalar_t new_mean = acc.mean + delta / (acc.nf + 1);
acc_scalar_t new_delta = data - new_mean;
return {
new_mean,
acc.m2 + delta * new_delta,
acc.n + 1,
combine_t(acc.n + 1), // accumulate for combine_t uses index_t
};
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
if (a.nf == 0) {
return b;
}
if (b.nf == 0) {
return a;
}
acc_scalar_t delta = b.mean - a.mean;
combine_t new_count = a.nf + b.nf;
acc_scalar_t nb_over_n = b.nf / new_count;
return {
a.mean + delta * nb_over_n,
a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
// setting acc.n as -1 since acc.n might not be able to represent the count
// correctly within its range, setting it to -1 to avoid confusion
-1,
new_count
};
}
inline C10_DEVICE res_t project(acc_t acc) const {
return res_t(acc.m2 / acc.nf, static_cast<scalar_t>(acc.mean));
}
inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
return {
__shfl_down(acc.mean, offset)
, __shfl_down(acc.m2, offset)
, __shfl_down(acc.n, offset)
, __shfl_down(acc.nf, offset)
};
}
};
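`WelfordOps::combine` is the standard parallel (Chan-style) merge of two `(mean, M2, count)` partials, and `reduce` is the single-sample update. A host-side sketch, assuming plain C++ rather than the device types above, that checks the merge against a serial pass:

```cpp
#include <cstdio>
#include <vector>

struct Welford { double mean = 0, m2 = 0, n = 0; };

// Single-sample update, mirroring WelfordOps::reduce.
static Welford Reduce(Welford a, double x) {
  const double delta = x - a.mean;
  a.mean += delta / (a.n + 1);
  a.m2 += delta * (x - a.mean);
  a.n += 1;
  return a;
}

// Merge of two partials, mirroring WelfordOps::combine.
static Welford Combine(Welford a, Welford b) {
  if (a.n == 0) return b;
  if (b.n == 0) return a;
  const double delta = b.mean - a.mean;
  const double n = a.n + b.n;
  return {a.mean + delta * b.n / n, a.m2 + b.m2 + delta * delta * a.n * b.n / n, n};
}

int main() {
  const std::vector<double> xs = {1, 2, 3, 4, 5, 6, 7, 8};
  Welford serial, lo, hi;
  for (double x : xs) serial = Reduce(serial, x);
  for (size_t i = 0; i < xs.size() / 2; ++i) lo = Reduce(lo, xs[i]);
  for (size_t i = xs.size() / 2; i < xs.size(); ++i) hi = Reduce(hi, xs[i]);
  const Welford merged = Combine(lo, hi);
  std::printf("serial: mean=%g m2=%g | merged: mean=%g m2=%g\n",
              serial.mean, serial.m2, merged.mean, merged.m2);  // both 4.5 / 42
  return 0;
}
```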
template <int max=32,typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val,const ReduceOp& op) {
#pragma unroll
for (int offset = max; offset > 0; offset >>= 1) {
val = op.combine(val, op.warp_shfl_down(val, offset));
}
return val;
}
template <typename T, class ReduceOp>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, T* shared) {
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduce(val, op);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0) {
val= shared[lid];
val = WarpReduce<4>(val,op);
}
return val;
}
template <int max=32,typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = max; offset > 0; offset >>= 1) {
val += __shfl_down(val, offset);
}
return val;
}
template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduceSum(val);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0) {
val= shared[lid];
val = WarpReduceSum<4>(val);
}
return val;
}
template <typename scalar_t>
__global__ void layernorm_forward_kernel(const scalar_t* input,scalar_t* ret,acc_type<scalar_t, true>* mean,acc_type<scalar_t, true>* rstd,
const scalar_t* gamma,const scalar_t* beta,IndexType cols,double eps)
{
    // dropout does nothing in eval mode
    IndexType i = blockIdx.x;
    // add + layernorm: compute the row's mean and rstd
using T_ACC = acc_type<scalar_t, true>;
using WelfordType = WelfordData<T_ACC, IndexType, T_ACC>;
using WelfordOp = WelfordOps<T_ACC, T_ACC, IndexType, T_ACC, thrust::pair<T_ACC, T_ACC>>;
__shared__ typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::type val_shared[BlockReduceNumThreads/C10_WARP_SIZE];
WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
WelfordOp welford_op;
WelfordType val;
#pragma unroll
for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
IndexType index = i * cols + j;
val = welford_op.reduce(val, static_cast<T_ACC>(input[index]));
}
val = BlockReduce(val,welford_op,val_shared_ptr);
__shared__ T_ACC s_mean;
__shared__ T_ACC s_rstd;
if (threadIdx.x == 0) {
thrust::tie(s_rstd, s_mean) = welford_op.project(val);
mean[i] = s_mean;
s_rstd=rsqrt(s_rstd + static_cast<T_ACC>(eps));
rstd[i] = s_rstd;
}
__syncthreads();
//layernorm (x-mean)*rstd*gamma+beta
#pragma unroll
for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
IndexType index = i * cols + j;
ret[index] = static_cast<scalar_t>((static_cast<T_ACC>(input[index]) - s_mean)*s_rstd * (gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]))
+ (beta == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta[j])));
}
}
template <typename T>
void LayerNormKernelImplInternal(
oneflow::ep::Stream* stream,
const T* X,
const T* gamma,
const T* beta,
int64_t M,
int64_t N,
double eps,
T* Y,
acc_type<T, true>* mean,
acc_type<T, true>* rstd) {
using T_ACC = acc_type<T, true>;
const T* X_data = X;
const T* gamma_data = gamma;
const T* beta_data = beta;
T* Y_data = Y;
T_ACC* mean_data = mean;
T_ACC* rstd_data = rstd;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
layernorm_forward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
X_data,Y_data,mean_data,rstd_data,gamma_data,beta_data,N,eps);
}
template <typename scalar_t>
__global__ void GammaBetaBackwardSimple(IndexType M,IndexType N,const scalar_t* dY,const scalar_t* X,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd,scalar_t* dg,scalar_t* db)
{
using T_ACC = acc_type<scalar_t, true>;
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
if (j < N) {
T_ACC sum1 = 0;
T_ACC sum2 = 0;
for (int64_t i = 0; i < M; ++i) {
const int64_t index = i * N + j;
sum1 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index]) *
(static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
static_cast<T_ACC>(rstd[i]);
sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
}
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
template <typename scalar_t>
__global__ void GammaBetaBackward(IndexType M,IndexType N,const scalar_t* dY,const scalar_t* X,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd,scalar_t* dg,scalar_t* db)
{
using T_ACC = acc_type<scalar_t, true>;
__shared__ T_ACC g_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
__shared__ T_ACC b_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
T_ACC dg_sum1 = 0;
T_ACC dg_sum2 = 0;
T_ACC db_sum1 = 0;
T_ACC db_sum2 = 0;
if (j < N) {
for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) {
const int64_t i1 = i;
const int64_t i2 = i + blockDim.y;
const int64_t index1 = i1 * N + j;
const int64_t index2 = i2 * N + j;
dg_sum1 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index1]) *
(static_cast<T_ACC>(X[index1]) - static_cast<T_ACC>(mean[i1])) *
static_cast<T_ACC>(rstd[i1]);
db_sum1 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index1]);
if (i2 < M) {
dg_sum2 += dg == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index2]) *
(static_cast<T_ACC>(X[index2]) - static_cast<T_ACC>(mean[i2])) *
static_cast<T_ACC>(rstd[i2]);
db_sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index2]);
}
}
}
g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
b_shared[threadIdx.y][threadIdx.x] = db_sum1;
b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
__syncthreads();
T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
sum1 = WarpReduceSum<16>(sum1);
sum2 = WarpReduceSum<16>(sum2);
if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
sum1 = WarpReduceSum<16>(sum1);
sum2 = WarpReduceSum<16>(sum2);
if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
if (j < N) {
if (dg != nullptr) {
dg[j] = static_cast<scalar_t>(sum1);
}
if (db != nullptr) {
db[j] = static_cast<scalar_t>(sum2);
}
}
}
}
template <typename scalar_t>
__global__ void LayerNormBackward_kernel(IndexType N,const scalar_t* dY,const scalar_t* X,const scalar_t* gamma,const acc_type<scalar_t, true>* mean,
const acc_type<scalar_t, true>* rstd, scalar_t* dX, const scalar_t* add_to_output)
{
using T_ACC = acc_type<scalar_t, true>;
__shared__ T_ACC ds_shared[C10_WARP_SIZE];
__shared__ T_ACC db_shared[C10_WARP_SIZE];
const IndexType i = blockIdx.x;
T_ACC sum1 = 0;
T_ACC sum2 = 0;
#pragma unroll
for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
const IndexType index = i * N + j;
const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
sum1 += static_cast<T_ACC>(dY[index]) * static_cast<T_ACC>(X[index]) * gamma_v;
sum2 += static_cast<T_ACC>(dY[index]) * gamma_v;
}
sum1 = BlockReduceSum<T_ACC>(sum1, ds_shared);
sum2 = BlockReduceSum<T_ACC>(sum2, db_shared);
const T_ACC s = T_ACC(1) / static_cast<T_ACC>(N);
__shared__ T_ACC b;
__shared__ T_ACC c;
if (threadIdx.x == 0) {
b = (sum2 * static_cast<T_ACC>(mean[i]) - sum1) * static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(rstd[i]) *static_cast<T_ACC>(rstd[i]) * s;
c = -(b * static_cast<T_ACC>(mean[i]) + sum2 * static_cast<T_ACC>(rstd[i]) * s);
}
__syncthreads();
#pragma unroll
for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
const IndexType index = i * N + j;
const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
dX[index] = static_cast<scalar_t>(static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(dY[index]) * gamma_v + b * static_cast<T_ACC>(X[index]) + c
+ (add_to_output == nullptr ? T_ACC(0) : static_cast<T_ACC>(add_to_output[index])));
}
}
template <typename T>
void LayerNormBackwardKernelImplInternal(
oneflow::ep::Stream* stream,
const T* dY,
const T* X,
const acc_type<T, true>* mean,
const acc_type<T, true>* rstd,
const T* gamma,
int64_t M,
int64_t N,
T* dX,
const T* add_to_output) {
using T_ACC = acc_type<T, true>;
const T* dY_data = dY;
const T* X_data = X;
const T_ACC* mean_data = mean;
const T_ACC* rstd_data = rstd;
const T* gamma_data = gamma;
T* dX_data = dX;
const T* add_to_output_data = add_to_output;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
if (dX_data != nullptr) {
LayerNormBackward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
N, dY_data, X_data,gamma_data,mean_data,rstd_data,dX_data,add_to_output_data);
}
}
template <typename T>
void LayerNormBackwardKernelImplInternalParam(
oneflow::ep::Stream* stream,
const T* dY,
const T* X,
const acc_type<T, true>* mean,
const acc_type<T, true>* rstd,
int64_t M,
int64_t N,
T* dgamma,
T* dbeta) {
using T_ACC = acc_type<T, true>;
const T* dY_data = dY;
const T* X_data = X;
const T_ACC* mean_data = mean;
const T_ACC* rstd_data = rstd;
hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
T* dgamma_data = dgamma;
T* dbeta_data = dbeta;
if (M < 512) {
// For small batch size, do colwise reduce directly.
const int64_t B = (N + NumThreads - 1) / NumThreads;
GammaBetaBackwardSimple<T>
<<<B, NumThreads, 0, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
} else {
const int64_t B =
(N + ColwiseReduceTileSize - 1) / ColwiseReduceTileSize;
constexpr int kThreadX = ColwiseReduceTileSize;
constexpr int kThreadY = ColwiseReduceTileSize / 2;
GammaBetaBackward<T>
<<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
}
}
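Both `GammaBetaBackwardSimple` and the tiled `GammaBetaBackward` compute the same column-wise sums, dgamma[j] = sum_i dY[i,j] * (X[i,j] - mean[i]) * rstd[i] and dbeta[j] = sum_i dY[i,j]; only the reduction strategy changes with M. A small host reference of those sums, assuming plain C++ and toy values:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int M = 3, N = 2;  // rows (instances) x cols (norm size)
  const std::vector<float> dY = {1, 2, 3, 4, 5, 6};
  const std::vector<float> X  = {1, 0, 2, 1, 3, 2};
  const std::vector<float> mean = {0.5f, 1.5f, 2.5f};
  const std::vector<float> rstd = {2.0f, 2.0f, 2.0f};
  std::vector<float> dgamma(N, 0.0f), dbeta(N, 0.0f);
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      const int idx = i * N + j;
      dgamma[j] += dY[idx] * (X[idx] - mean[i]) * rstd[i];
      dbeta[j]  += dY[idx];
    }
  }
  std::printf("dgamma = {%g, %g}, dbeta = {%g, %g}\n",
              dgamma[0], dgamma[1], dbeta[0], dbeta[1]);
  return 0;
}
```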
namespace oneflow {
template<typename T>
class LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
LayerNormGpuKernel() = default;
~LayerNormGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
double epsilon = ctx->Attr<double>("epsilon");
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
const T* gamma_ptr = nullptr;
const T* beta_ptr = nullptr;
if (ctx->has_input("gamma", 0)) {
const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
gamma_ptr = gamma->dptr<T>();
CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);
}
if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr<T>(); }
// DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
// gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormKernelImplInternal<T>(ctx->stream(), x->dptr<T>(), gamma_ptr, beta_ptr, num_instances, norm_size, epsilon,
y->mut_dptr<T>(), mean->mut_dptr<ComputeType>(), inv_variance->mut_dptr<ComputeType>());
};
};
#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm") \
.SetCreateFn<LayerNormGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));
REGISTER_LAYER_NORM_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
LayerNormGradGpuKernel() = default;
~LayerNormGradGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
const T* gamma_ptr = nullptr;
if (ctx->has_input("gamma", 0)) {
gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr<T>();
}
const T* add_to_output_ptr = nullptr;
if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), dx->data_type());
CHECK_EQ(add_to_output->shape_view(), dx->shape_view());
add_to_output_ptr = add_to_output->dptr<T>();
}
// LaunchLayerNormBackward<T>(ctx->stream(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(),
// mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr<T>());
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
LayerNormBackwardKernelImplInternal<T>(ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(),
gamma_ptr, num_instances, norm_size, dx->mut_dptr<T>(), add_to_output_ptr);
};
};
#define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_grad") \
.SetCreateFn<LayerNormGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn( \
[](const user_op::InferContext& ctx, \
const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { \
if (ctx.has_input("_add_to_output", 0)) { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \
} \
return Maybe<void>::Ok(); \
});
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
LayerNormParamGradGpuKernel() = default;
~LayerNormParamGradGpuKernel() = default;
private:
using user_op::OpKernel::Compute;
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
int64_t num_instances = mean->shape_view().elem_cnt();
int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
// const DataType data_type = dy->data_type();
// const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
// const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
// const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
// T* tmp_gamma_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());
// T* tmp_beta_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_gamma_diff_size);
// T* reduce_buf_ptr =
// reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
// LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
// 0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
// num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
// inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
// const int32_t m = norm_size;
// const int32_t n = 1;
// const int32_t k = grid_dim_y;
// std::unique_ptr<ep::primitive::Fill> fill =
// ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),
// data_type);
// CHECK(fill);
// fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y);
// std::unique_ptr<ep::primitive::Matmul> matmul =
// ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
// ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,
// ep::primitive::BlasTransposeType::N);
// CHECK(matmul);
// if (ctx->has_output("gamma_diff", 0)) {
// user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
// matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0,
// gamma_diff->mut_dptr());
// }
// if (ctx->has_output("beta_diff", 0)) {
// user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
// matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0,
// beta_diff->mut_dptr());
// }
T* gamma_diff_ptr = nullptr;
T* beta_diff_ptr = nullptr;
if (ctx->has_output("gamma_diff", 0)) {
gamma_diff_ptr = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0)->mut_dptr<T>();
}
if (ctx->has_output("beta_diff", 0)) {
beta_diff_ptr = ctx->Tensor4ArgNameAndIndex("beta_diff", 0)->mut_dptr<T>();
}
LayerNormBackwardKernelImplInternalParam<T>(ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(), inv_variance->dptr<ComputeType>(),
num_instances, norm_size, gamma_diff_ptr, beta_diff_ptr);
};
};
#define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_param_grad") \
.SetCreateFn<LayerNormParamGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis"); \
const bool has_gamma_diff = ctx->has_output("gamma_diff", 0); \
const bool has_beta_diff = ctx->has_output("beta_diff", 0); \
const auto& dy = ctx->InputTensorDesc("dy", 0); \
const int64_t num_instances = dy.shape().Count(0, begin_params_axis); \
const int64_t norm_size = dy.shape().Count(begin_params_axis); \
const int grid_dim_y = num_instances; \
size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \
return tmp_buffer_size; \
});
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif
}
} // namespace oneflow
#endif
......@@ -28,7 +28,7 @@ limitations under the License.
#elif defined(__HIPCC__)
#include <hip/hsa_detail/math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
......
......@@ -882,6 +882,22 @@ namespace oneflow {
namespace {
template<typename T>
void printTensor(const std::string& str, const T* devTensor, size_t size) {
  T* hostTensor = (T*)malloc(size * sizeof(T));
  hipDeviceSynchronize();  // make sure the producing kernel has finished before the copy
  hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
std::cout << str << ": ";
  for (size_t i = 0; i < size; i++) {
if (i % 16 == 0) {
std::cout << std::endl;
}
std::cout << hostTensor[i] << ", ";
}
std::cout << str << ": finish" << std::endl;
free(hostTensor);
}
hipdnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {
if (dim == 2) {
return HIPDNN_BATCHNORM_PER_ACTIVATION;
......@@ -969,6 +985,15 @@ class CudnnTensorDescHelper final {
int32_t param_size_ = 0;
};
size_t InferInferTmpSize(user_op::InferContext* ctx) {
const auto& y = ctx->OutputTensorDesc("y", 0);
if (ctx->has_input("_add_to_output", 0)) {
return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
} else {
return 1;
}
}
size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
const int32_t axis) {
return 1;
......@@ -976,8 +1001,13 @@ size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_typ
size_t InferTrainTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0);
const auto& y = ctx->OutputTensorDesc("y", 0);
const auto axis = ctx->Attr<int32_t>("axis");
if (ctx->has_input("_add_to_output", 0)) {
return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
} else {
return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);
}
}
size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
......@@ -1016,6 +1046,9 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon");
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
const DataType data_type = x->data_type();
CHECK_EQ(x->shape_view(), y->shape_view());
CHECK_EQ(y->data_type(), data_type);
......@@ -1030,17 +1063,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
desc_helper.CheckParamTensor(moving_variance);
const void* sp_alpha = CudnnSPOnePtr(data_type);
const void* sp_beta;
const void* sp_beta = CudnnSPZeroPtr(data_type);
if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), y->data_type());
CHECK_EQ(add_to_output->shape_view(), y->shape_view());
Memcpy<DeviceType::kCUDA>(
ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
sp_beta = CudnnSPOnePtr(data_type);
} else {
sp_beta = CudnnSPZeroPtr(data_type);
}
OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardInference(
......@@ -1048,6 +1079,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),
moving_variance->dptr(), epsilon));
if (ctx->has_input("_add_to_output", 0)) {
sp_beta = CudnnSPOnePtr(data_type);
OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
sp_alpha, desc_helper.xy_desc(),
add_to_output_dev, sp_beta, desc_helper.xy_desc(),
y->mut_dptr()));
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
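With `_add_to_output`, the addend is staged into `tmp_buffer` before the BN forward and folded back in afterwards via `hipdnnAddTensor` with alpha = beta = 1, rather than pre-loading `y` and relying on the BN call's beta. A host-side sketch of that ordering, assuming plain C++ and no cuDNN/hipDNN calls:

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> y = {0, 0, 0, 0};            // kernel output buffer
  const std::vector<float> addend = {1, 2, 3, 4}; // _add_to_output input
  std::vector<float> staged = addend;             // Memcpy into tmp_buffer

  for (float& v : y) v = 10.0f;                   // stand-in for the BN forward overwriting y

  const float alpha = 1.0f, beta = 1.0f;          // AddTensor-style blend: y = alpha*staged + beta*y
  for (size_t i = 0; i < y.size(); ++i) y[i] = alpha * staged[i] + beta * y[i];

  for (float v : y) std::printf("%g ", v);        // 11 12 13 14
  std::printf("\n");
  return 0;
}
```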
......@@ -1057,6 +1097,7 @@ REGISTER_USER_KERNEL("normalization")
.SetCreateFn<NormalizationInferenceKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
&& (user_op::HobAttr<bool>("training") == false))
.SetInferTmpSizeFn(InferInferTmpSize)
.SetInplaceProposalFn([](const user_op::InferContext& ctx,
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
if (ctx.has_input("_add_to_output", 0)) {
......@@ -1068,76 +1109,78 @@ REGISTER_USER_KERNEL("normalization")
constexpr int64_t kCudaWarpSize = 64;
template<typename T>
__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {
__global__ void ReluGpu(int64_t n, const T* x, T* y, int64_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) {
const T x_val = x[i];
const bool is_positive = (x_val > zero);
int32_t warp_mask = __ballot(static_cast<int>(is_positive));
if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
y[i] = is_positive ? x_val : zero;
}
}
template<typename T>
__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
const int32_t lane_id = threadIdx.x % kCudaWarpSize;
const T zero = static_cast<T>(0.f);
CUDA_1D_KERNEL_LOOP(i, n) {
const T sum = x[i] + addend[i];
const bool is_positive = (sum > zero);
int32_t warp_mask = __ballot(static_cast<int>(is_positive));
if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
y[i] = is_positive ? sum : zero;
}
}
template<typename T>
void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) {
void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int64_t* mask) {
ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask);
}
template<typename T>
void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask);
}
template<typename T>
__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
int32_t lane_id = threadIdx.x % kCudaWarpSize;
__global__ void ReluBackwardGpu(int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
int64_t lane_id = threadIdx.x % kCudaWarpSize;
CUDA_1D_KERNEL_LOOP(i, n) {
int32_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & (1 << lane_id);
int64_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & ((int64_t)1 << lane_id);
addend_diff[i] = static_cast<T>(is_positive) * dy[i];
}
}
#if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
template<>
__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
nv_bfloat16* addend_diff) {
int32_t lane_id = threadIdx.x % kCudaWarpSize;
CUDA_1D_KERNEL_LOOP(i, n) {
int32_t mask_val = mask[i / kCudaWarpSize];
bool is_positive = mask_val & (1 << lane_id);
addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
}
}
// template<>
// __global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
// nv_bfloat16* addend_diff) {
// int32_t lane_id = threadIdx.x % kCudaWarpSize;
// CUDA_1D_KERNEL_LOOP(i, n) {
// int32_t mask_val = mask[i / kCudaWarpSize];
// bool is_positive = mask_val & (1 << lane_id);
// addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
// }
// }
#endif
// #endif
template<typename T>
void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
void ReluBackward(ep::Stream* stream, int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff);
}
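The relu mask storage moves from one `int32_t` per 32-lane warp to one `int64_t` per 64-lane wavefront, and the backward pass must test bit `lane_id` with a 64-bit shift (a 32-bit `1 << lane_id` is undefined for lanes 32..63). A host-side sketch of the pack/unpack, assuming plain C++:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int kWaveSize = 64;
  std::vector<float> x(kWaveSize);
  for (int i = 0; i < kWaveSize; ++i) x[i] = (i % 3 == 0) ? -1.0f : 1.0f;

  // Forward: pack one sign bit per lane, the way __ballot builds the wavefront mask.
  uint64_t mask = 0;
  for (int lane = 0; lane < kWaveSize; ++lane) {
    if (x[lane] > 0.0f) mask |= 1ull << lane;
  }

  // Backward: gate the gradient by the stored bit for each lane.
  int passed = 0;
  for (int lane = 0; lane < kWaveSize; ++lane) {
    const bool is_positive = (mask & (1ull << lane)) != 0;
    passed += is_positive ? 1 : 0;
  }
  std::printf("lanes passing gradient: %d of %d\n", passed, kWaveSize);
  return 0;
}
```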
void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y,
int32_t* mask) {
int64_t* mask) {
if (data_type == kFloat) {
Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask);
} else if (data_type == kDouble) {
......@@ -1156,7 +1199,7 @@ void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x
}
}
void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x,
const void* addend, void* y, int32_t* mask) {
const void* addend, void* y, int64_t* mask) {
if (data_type == kFloat) {
AddRelu<float>(stream, n, reinterpret_cast<const float*>(x),
reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask);
......@@ -1178,7 +1221,7 @@ void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void
UNIMPLEMENTED();
}
}
void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask,
void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int64_t* mask,
const void* dy, void* addend_diff) {
if (data_type == kFloat) {
ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy),
......@@ -1225,6 +1268,9 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
hipdnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
......@@ -1244,17 +1290,15 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
desc_helper.CheckParamTensor(moving_variance);
}
const void* sp_alpha = CudnnSPOnePtr(data_type);
const void* sp_beta;
const void* sp_beta = CudnnSPZeroPtr(data_type);
if (ctx->has_input("_add_to_output", 0)) {
const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
CHECK_EQ(add_to_output->data_type(), y->data_type());
CHECK_EQ(add_to_output->shape_view(), y->shape_view());
Memcpy<DeviceType::kCUDA>(
ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
sp_beta = CudnnSPOnePtr(data_type);
} else {
sp_beta = CudnnSPZeroPtr(data_type);
}
OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardTraining(
......@@ -1265,6 +1309,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),
inv_variance->mut_dptr()));
if (ctx->has_input("_add_to_output", 0)) {
sp_beta = CudnnSPOnePtr(data_type);
OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
sp_alpha, desc_helper.xy_desc(),
add_to_output_dev, sp_beta, desc_helper.xy_desc(),
y->mut_dptr()));
}
if (ctx->op_type_name() == "normalization_add_relu") {
CHECK(!ctx->has_input("_add_to_output", 0));
const int64_t elem_cnt = x->shape_view().elem_cnt();
......@@ -1272,10 +1324,10 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
if (ctx->has_input("addend", 0)) {
const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0);
AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),
mask->mut_dptr<int32_t>());
mask->mut_dptr<int64_t>());
} else {
Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),
mask->mut_dptr<int32_t>());
mask->mut_dptr<int64_t>());
}
}
}
......@@ -1351,7 +1403,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
if (ctx->has_output("addend_diff", 0)) {
user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0);
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
addend_diff->mut_dptr());
bn_workspace_ptr = tmp_buffer->mut_dptr();
bn_workspace_size = tmp_buffer->shape_view().elem_cnt();
......@@ -1361,7 +1413,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
const size_t relu_dx_size =
GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type));
CHECK_GE(tmp_buffer_size, relu_dx_size);
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
tmp_buffer->mut_dptr());
bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;
bn_workspace_size = tmp_buffer_size - relu_dx_size;
......@@ -1393,231 +1445,6 @@ REGISTER_USER_KERNEL("normalization_add_relu_grad")
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferGradTmpSize);
#if (HIPDNN_VERSION >= 7401)
size_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
cudnnBatchNormOps_t ops;
hipdnnTensorDescriptor_t z_desc;
if (ctx->has_input("addend", 0)) {
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
z_desc = desc_helper.xy_desc();
} else {
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
z_desc = nullptr;
}
OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc,
desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
Singleton<CudnnHandlePool>::Get()->Put(handle);
return std::max(size_in_bytes, static_cast<size_t>(1));
}
size_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) {
const auto& x = ctx->InputTensorDesc("x", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
cudnnBatchNormOps_t ops;
hipdnnTensorDescriptor_t z_desc;
if (ctx->has_output("addend_diff", 0)) {
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
z_desc = desc_helper.xy_desc();
} else {
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
z_desc = nullptr;
}
OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), desc_helper.xy_desc(),
desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
activation_desc.Get(), &size_in_bytes));
Singleton<CudnnHandlePool>::Get()->Put(handle);
return std::max(size_in_bytes, static_cast<size_t>(1));
}
class FusedNormalizationAddReluKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
FusedNormalizationAddReluKernel() = default;
~FusedNormalizationAddReluKernel() override = default;
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0);
auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon");
const auto momentum = ctx->Attr<float>("momentum");
auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const DataType data_type = x->data_type();
CHECK_EQ(x->shape_view(), y->shape_view());
CHECK_EQ(y->data_type(), data_type);
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(moving_mean);
desc_helper.CheckParamTensor(moving_variance);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
hipdnnTensorDescriptor_t z_desc;
const void* z_ptr;
cudnnBatchNormOps_t ops;
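// The optional "addend" input selects the fused operation: BN + residual add + ReLU
// (CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) when present, plain BN + ReLU otherwise;
// z_desc/z_ptr carry the residual tensor.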
if (ctx->has_input("addend", 0)) {
z_desc = desc_helper.xy_desc();
z_ptr = ctx->Tensor4ArgNameAndIndex("addend", 0)->dptr();
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
z_desc = nullptr;
z_ptr = nullptr;
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
}
size_t min_workspace_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
activation_desc.Get(), &min_workspace_size));
const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
CHECK_GE(workspace_size, min_workspace_size);
size_t min_reserve_space_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
CHECK_GE(reserve_space_size, min_reserve_space_size);
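// Both the scratch workspace and the persistent reserve space must be at least as
// large as the sizes reported by the library; the reserve space written here is
// consumed unchanged by the backward kernel further down.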
OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(),
z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(),
gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(),
activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(),
reserve_space_size));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu")
.SetCreateFn<FusedNormalizationAddReluKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize);
class FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel,
public user_op::CudaGraphSupport {
public:
FusedNormalizationAddReluGradUserKernel() = default;
~FusedNormalizationAddReluGradUserKernel() override = default;
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
const auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
const auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const auto axis = ctx->Attr<int32_t>("axis");
const auto epsilon = ctx->Attr<float>("epsilon");
const DataType data_type = x->data_type();
CHECK_EQ(dy->shape_view(), x->shape_view());
CHECK_EQ(dy->data_type(), data_type);
CHECK_EQ(dx->shape_view(), x->shape_view());
CHECK_EQ(dx->data_type(), data_type);
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape_view().NumAxes());
const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(gamma_diff);
desc_helper.CheckParamTensor(beta_diff);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
hipdnnTensorDescriptor_t dz_desc;
void* dz_ptr;
cudnnBatchNormOps_t ops;
if (ctx->has_output("addend_diff", 0)) {
dz_desc = desc_helper.xy_desc();
dz_ptr = ctx->Tensor4ArgNameAndIndex("addend_diff", 0)->mut_dptr();
ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
dz_desc = nullptr;
dz_ptr = nullptr;
ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
}
size_t min_workspace_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc,
desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(),
&min_workspace_size));
const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
CHECK_GE(workspace_size, min_workspace_size);
size_t min_reserve_space_size;
OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
CHECK_GE(reserve_space_size, min_reserve_space_size);
OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),
CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),
y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(),
dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(),
gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(),
activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size,
const_cast<void*>(reserve_space->dptr()), reserve_space_size));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu_grad")
.SetCreateFn<FusedNormalizationAddReluGradUserKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
.SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize);
#endif
} // namespace
} // namespace oneflow
......
......@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
reserve_space_bits = reserve_space_bits / split_num;
}
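// reserve_space stores one mask bit per element, packed into whole words: 64-bit
// words on ROCm, 32-bit words on CUDA, hence the different RoundUp divisors below.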
#ifdef WITH_ROCM
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
#else
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
const auto& x_desc = ctx->InputTensorDesc("x", 0);
#ifdef WITH_ROCM
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
#else
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
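// The reserve-space element type mirrors the packing above: 64-bit words under
// WITH_ROCM, 32-bit words otherwise.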
#ifdef WITH_ROCM
reserve_space->set_data_type(DataType::kInt64);
#else
reserve_space->set_data_type(DataType::kInt32);
#endif
return Maybe<void>::Ok();
})(ctx);
}
......