Commit 6046d8fb authored by yuguo960516yuguo

dtk23.04

parent a715222c
......@@ -334,16 +334,16 @@ elseif(WIN32)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
endif()
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS "-O0")
endif()
# if (BUILD_ROCM)
# # AMD compiler fails to compile these three files with '-O1/2/3'.
# # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# # so '-O0' will override '-O1/2/3'.
# set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
# ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
# # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
# PROPERTIES COMPILE_OPTIONS "-O0")
# endif()
if(BUILD_CUDA)
string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
......
......@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template<typename T>
OF_DEVICE_FUNC T DeviceMin(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a < b ? a : b;
#else
return std::min(a, b);
......@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template<typename T>
OF_DEVICE_FUNC T DeviceMax(T a, T b) {
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return a > b ? a : b;
#else
return std::max(a, b);
......
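Note on the guard change above: nvcc defines __CUDA_ARCH__ only during its device-compilation pass, while hipcc's device pass defines __HIP_DEVICE_COMPILE__ instead, so without the extra clause the HIP device build fell through to the std:: host branch. A minimal sketch of the same pattern (DeviceClamp is a hypothetical example, not a function from this file):

template<typename T>
OF_DEVICE_FUNC T DeviceClamp(T v, T lo, T hi) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  return v < lo ? lo : (v > hi ? hi : v);  // device pass: avoid std:: calls
#else
  return std::min(std::max(v, lo), hi);    // host pass: standard library
#endif
}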
......@@ -54,7 +54,7 @@ struct CastCASImpl {
}
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)
template<typename T>
struct CastCASImpl<T, unsigned short int> {
......
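The <T, unsigned short int> specialization enabled here is the usual fallback for 16-bit compare-and-swap: native atomicCAS on unsigned short requires sm_70+, so clang CUDA and ROCm take an emulated path that performs a 32-bit CAS on the aligned word containing the 16-bit value. A hedged sketch of that classic emulation (AtomicCAS16 is an illustrative name, not this file's actual helper):

#include <cstdint>

__device__ inline unsigned short AtomicCAS16(unsigned short* addr, unsigned short compare,
                                             unsigned short val) {
  const std::uintptr_t p = reinterpret_cast<std::uintptr_t>(addr);
  unsigned int* word = reinterpret_cast<unsigned int*>(p & ~std::uintptr_t{3});
  const bool hi_half = (p & 2) != 0;  // which 16-bit half of the word we own
  unsigned int old_word = *word;
  unsigned int assumed;
  do {
    assumed = old_word;
    const unsigned short cur = hi_half ? (assumed >> 16) : (assumed & 0xFFFFu);
    const unsigned short next = (cur == compare) ? val : cur;  // CAS semantics on 16 bits
    const unsigned int repl = hi_half
        ? ((assumed & 0x0000FFFFu) | (static_cast<unsigned int>(next) << 16))
        : ((assumed & 0xFFFF0000u) | next);
    old_word = atomicCAS(word, assumed, repl);  // retry until the word was unchanged
  } while (old_word != assumed);
  return hi_half ? static_cast<unsigned short>(old_word >> 16)
                 : static_cast<unsigned short>(old_word & 0xFFFFu);
}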
......@@ -308,7 +308,8 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
#ifdef WITH_ROCM
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
......@@ -318,10 +319,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
#ifdef WITH_ROCM
__syncthreads();
__syncthreads();
if (wid == 0) {
#else
__syncwarp();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
#endif
T block_mean = 0;
T block_m2 = 0;
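For orientation, the Welford block reduction above merges per-warp partial statistics through shared memory; the combine step itself is the standard Chan et al. pairwise update, sketched below under the assumption that OneFlow's helper follows the textbook form. The surrounding diff swaps __syncwarp() for __syncthreads() on ROCm because HIP provides no __syncwarp, and the shared arrays are indexed by kWarpSize, which is 64 on AMD wavefronts rather than 32.

template<typename T>
__device__ inline void WelfordCombine(T b_mean, T b_m2, T b_count,
                                      T* mean, T* m2, T* count) {
  if (b_count == 0) { return; }
  const T new_count = *count + b_count;
  const T nb_over_n = b_count / new_count;             // weight of the incoming batch
  const T delta = b_mean - *mean;
  *mean += delta * nb_over_n;                          // weighted mean update
  *m2 += b_m2 + delta * delta * (*count) * nb_over_n;  // cross-term for the merged M2
  *count = new_count;
}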
......@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
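The `waves` constant bounds the grid size: the launcher requests at most `waves` resident block-waves per multiprocessor, and the ROCm path raises the cap from 32 to 64, presumably to keep 64-lane-wavefront GPUs saturated. A sketch of the usual sizing helper, written with plain CUDA spellings (the file itself goes through its GPU() prefix macro, and the exact occupancy formula here is an assumption):

#include <algorithm>
#include <cstdint>

inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
                                int* num_blocks) {
  int dev = 0;
  cudaError_t err = cudaGetDevice(&dev);
  if (err != cudaSuccess) { return err; }
  int sm_count = 0;
  err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  if (err != cudaSuccess) { return err; }
  int tpm = 0;
  err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
  if (err != cudaSuccess) { return err; }
  // Cap the grid at `waves` full waves of blocks across the whole device.
  const int64_t cap = static_cast<int64_t>(sm_count) * tpm / block_size * waves;
  *num_blocks = static_cast<int>(std::max<int64_t>(1, std::min(max_blocks, cap)));
  return cudaSuccess;
}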
......@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
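The DEFINE_ONE_ELIF ladder expands into an else-if chain that picks a thread-group width for the warp-level kernel from the row width `cols`; the ROCm build keeps a single 64-wide rung (the full wavefront) where CUDA dispatches over 4/8/16/32 lanes. An illustrative reduction of the decision the generated chain makes (the helper name is hypothetical and the thresholds are modeled on the `(max_col)*kWarpSize` form used below, not copied from the elided macro head):

inline int PickThreadGroupWidth(int64_t cols, int warp_size) {
#ifdef WITH_ROCM
  (void)cols;
  return 64;  // one rung: the whole 64-lane wavefront
#else
  const int widths[4] = {4, 8, 16, 32};
  for (int width : widths) {
    if (cols <= static_cast<int64_t>(width) * warp_size) { return width; }
  }
  return 32;  // widest group; larger rows fall through to later ladders
#endif
}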
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
......@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
int grid_dim_x;
{
GPU(Error_t) err =
......@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
constexpr int block_size = 128;
#ifdef WITH_ROCM
constexpr int waves = 64;
#else
constexpr int waves = 32;
#endif
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
......@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF(64)
#else
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#endif
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
......
......@@ -252,15 +252,21 @@ struct SetContext {
const int lane_age = ages[thread_ctx.lane_id];
const bool lane_hit = (lane_key == key && lane_age != 0);
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_hit);
const unsigned long long int hit_mask = __ballot(lane_hit);
if (hit_mask != 0) {
return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
} else {
return -1;
}
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
#endif
if (hit_mask != 0) {
return __ffs(static_cast<int>(hit_mask)) - 1;
} else {
return -1;
}
#endif
}
__device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,
......@@ -277,15 +283,16 @@ struct SetContext {
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
#else
const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
#endif
if (hit_mask != 0) {
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
#ifdef WITH_ROCM
insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
const int insert_way_age = __shfl(lane_age, insert_way);
#else
insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
#endif
if (lane_age > insert_way_age) {
......@@ -301,12 +308,16 @@ __syncwarp();
}
if (insert_way == -1) {
#ifdef WITH_ROCM
const unsigned valid_mask = __ballot(lane_age != 0);
const unsigned long long int valid_mask = __ballot(lane_age != 0);
#else
const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
#endif
if (valid_mask != kFullMask) {
#ifdef WITH_ROCM
insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
#else
insert_way = __popc(static_cast<int>(valid_mask));
#endif
if (lane_age > 0) {
lane_age -= 1;
} else if (thread_ctx.lane_id == insert_way) {
......@@ -329,7 +340,8 @@ __syncwarp();
const Key lane_key = keys[thread_ctx.lane_id];
int lane_age = ages[thread_ctx.lane_id];
#ifdef WITH_ROCM
const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
const int insert_way = __ffsll(valid_mask_tmp) - 1;
#else
const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
#endif
......@@ -594,7 +606,8 @@ __syncwarp();
warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
#ifdef WITH_ROCM
const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
const int key_count = __popcll(valid_mask_tmp);
#else
const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
#endif
......
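The recurring edit in this file is the same portability fix applied by hand at each call site: AMD wavefronts are 64 lanes wide, so HIP's __ballot returns a 64-bit mask, and the 32-bit helpers (__ffs, __popc) would silently ignore the upper 32 lanes; likewise __shfl takes no mask argument where CUDA uses __shfl_sync. A sketch of an alias that would centralize the rule (hypothetical names; note that any comparison against a 32-bit kFullMask would need the same widening on ROCm):

#ifdef WITH_ROCM
using LaneMask = unsigned long long int;  // one bit per lane of a 64-wide wavefront
__device__ inline int FindFirstSet(LaneMask m) { return __ffsll(m); }
__device__ inline int PopCount(LaneMask m) { return __popcll(m); }
#else
using LaneMask = unsigned;                // one bit per lane of a 32-wide warp
__device__ inline int FindFirstSet(LaneMask m) { return __ffs(static_cast<int>(m)); }
__device__ inline int PopCount(LaneMask m) { return __popc(static_cast<int>(m)); }
#endif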
......@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
......
......@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
OF_DEVICE_FUNC half operator()(half src) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);
......@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* dst, const half* src) const {
const half2 src2 = *(reinterpret_cast<const half2*>(src));
const float2 tanh_in = __half22float2(__hmul2(
......
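Both guarded branches funnel into unary_functor_internal::TanhApprox; extending the fast-path guard with `|| defined(WITH_ROCM)` works because only the CUDA side of that helper relies on the sm_75+ hardware instruction. A sketch of the helper under the assumption that it follows the common FasterTransformer-style form:

__device__ inline float TanhApprox(float x) {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
  float r;
  asm("tanh.approx.f32 %0,%1;" : "=f"(r) : "f"(x));  // Turing+ hardware tanh
  return r;
#else
  return tanhf(x);  // ROCm and older NVIDIA parts: plain device tanh
#endif
}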
......@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like",
"layer_norm_grad",
"layer_norm",
"layer_norm_param_grad"
};
return black_list;
}
"layer_norm_param_grad",
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {"add_n",
"add_n",
"tf_avg_pool_1d",
"tf_avg_pool_1d_grad",
"tf_avg_pool_2d",
"tf_avg_pool_2d_grad",
"tf_avg_pool_3d",
"tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad",
// "avg_pool_1d",
// "avg_pool_1d_grad",
// "avg_pool_2d",
// "avg_pool_2d_grad",
// "avg_pool_3d",
// "avg_pool_3d_grad",
"bias_add",
"sigmoid_grad",
"tanh",
......@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad",
"silu",
"silu_grad",
"fused_weighted_sum"};
return gray_list;
}
"fused_weighted_sum",
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {"broadcast_like",
"broadcast_like",
"gather",
"gather_nd",
"scatter_nd",
......@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad",
"tf_max_pool_3d",
"tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad",
// "max_pool_1d",
// "max_pool_1d_grad",
// "max_pool_2d",
// "max_pool_2d_grad",
// "max_pool_3d",
// "max_pool_3d_grad",
"reshape",
"reshape_like",
"relu",
......@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity",
"to_contiguous",
"copy",
"upsample_nearest_2d"};
"upsample_nearest_2d"
};
return black_list;
}
const AMPList& AutoMixedPrecisionLists::GrayList() {
static AMPList gray_list = {
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d",
"avg_pool_1d_grad",
"avg_pool_2d",
"avg_pool_2d_grad",
"avg_pool_3d",
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return gray_list;
}
const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): tuple_identity
static AMPList clear_list = {
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d",
"max_pool_1d_grad",
"max_pool_2d",
"max_pool_2d_grad",
"max_pool_3d",
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return clear_list;
}
......
......@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}
OF_DEVICE_FUNC half operator()(const half x, const half m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const float tanh_in =
__half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
const float tanh_out = TanhApprox(tanh_in);
......@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* y, const half* x, const half* m) const {
const half2 x2 = *(reinterpret_cast<const half2*>(x));
const float2 tanh_in = __half22float2(
......@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
const half& m) const {
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
const half halpha = __float2half_rn(alpha);
const half hbeta = __float2half_rn(beta);
const half hone = __float2half_rn(1.0F);
......@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
__device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
const half* m) const {
const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
......
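The Apply2 overloads above process a pair of half values per call through a single half2, the standard trick for halving the number of global-memory transactions on fp16 data. A minimal standalone illustration of the pattern (hypothetical kernel; assumes x and y are 4-byte aligned and the element count is even, with n2 = n / 2):

__global__ void PackedHalfScale(half* y, const half* x, const float scale, const int n2) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    const half2 v = reinterpret_cast<const half2*>(x)[i];   // one 32-bit load, two halves
    float2 f = __half22float2(v);                           // widen for the arithmetic
    f.x *= scale;
    f.y *= scale;
    reinterpret_cast<half2*>(y)[i] = __float22half2_rn(f);  // one 32-bit store
  }
}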
......@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr int kReduceBlockSize = 512;
constexpr int kBlockSize = 128;
#ifdef WITH_ROCM
constexpr int kNumWaves = 64;
#else
constexpr int kNumWaves = 32;
#endif
inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
int dev;
......
......@@ -28,7 +28,7 @@ limitations under the License.
#elif defined(__HIPCC__)
#include <hip/hsa_detail/math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
......
......@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
reserve_space_bits = reserve_space_bits / split_num;
}
#ifdef WITH_ROCM
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
#else
reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
const auto& x_desc = ctx->InputTensorDesc("x", 0);
#ifdef WITH_ROCM
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
#else
reserve_space->set_shape(
Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
#endif
return Maybe<void>::Ok();
})(ctx);
}
......@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
user_op::TensorDesc* reserve_space) -> Maybe<void> {
#ifdef WITH_ROCM
reserve_space->set_data_type(DataType::kInt64);
#else
reserve_space->set_data_type(DataType::kInt32);
#endif
return Maybe<void>::Ok();
})(ctx);
}
......
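The reserve_space sizing changed above packs one mask bit per input element into machine words: one 32-bit word per 32 elements on CUDA (hence DataType::kInt32), and one 64-bit word per 64 elements on ROCm (kInt64), matching the 64-lane wavefront ballots that produce the mask. A sketch of the arithmetic, with RoundUp assumed to round its argument up to a multiple:

constexpr int64_t RoundUpTo(int64_t n, int64_t multiple) {
  return (n + multiple - 1) / multiple * multiple;
}

inline int64_t ReserveSpaceWords(int64_t elem_cnt) {
#ifdef WITH_ROCM
  return RoundUpTo(elem_cnt, 64) / 64;  // one int64 word per 64 mask bits
#else
  return RoundUpTo(elem_cnt, 32) / 32;  // one int32 word per 32 mask bits
#endif
}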