Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
Pipeline #492 failed in 0 seconds
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/cpu/tensor/concatbase.h"
namespace onnxruntime {
namespace rocm {
class Concat final : public RocmKernel, public ConcatBase {
public:
Concat(const OpKernelInfo& info) : RocmKernel(info), ConcatBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tensor/concat_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace rocm {
namespace {
#ifdef USE_ROCM
constexpr int kNumElementsPerThread = 2;
constexpr int kNumThreadsPerBlock = 512;
#else
constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread;
constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock;
#endif
} // namespace
// The concat dimension size is the same for all inputs.
template <typename T, typename InputDataArray>
__global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div,
const fast_divmod concat_dim_size, T* output_data, InputDataArray input_data,
const HIP_LONG N) {
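// The flat output index id is decomposed with fast_divmod: first into an outer block index and an offset within
// that block (block_size_including_axis_dim elements), then into a position along the whole concat axis and an
// inner offset (block_size_inside_axis_dim elements), and finally the axis position is split into the input index
// and the offset along that input's concat axis, since every input contributes concat_dim_size to the axis.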
HIP_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset, input_index, block_offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
concat_dim_size.divmod(block_index, input_index, block_offset);
HIP_LONG input_pos =
(outer_block_index * concat_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
id += kNumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
output_data[id] = value[i];
id += kNumThreadsPerBlock;
}
}
}
template <typename InputDataArray>
Status ConcatSameConcatDimImpl(hipStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
const InputDataArray input_data, const size_t output_size) {
HIP_LONG N = static_cast<HIP_LONG>(output_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
fast_divmod concat_dim_size = fast_divmod(static_cast<int>(concat_size));
switch (element_bytes) {
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
hipLaunchKernelGGL(_ConcatKernelSameConcatDim, blocksPerGrid, kNumThreadsPerBlock, 0, stream, \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_dim_size, \
reinterpret_cast<ToHipType<type>::MappedType*>(output_data), input_data, N); \
} break
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
}
return Status::OK();
}
// input tensor addresses in device memory
template Status ConcatSameConcatDimImpl<const void**>(hipStream_t stream, const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size,
void* output_data, const void** input_data,
const size_t output_size);
// input tensor addresses passed by value
template Status ConcatSameConcatDimImpl<TArray<const void*, 32>>(hipStream_t stream, const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size, void* output_data,
TArray<const void*, 32> input_data,
const size_t output_size);
template <typename T>
__global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div, const int64_t* concat_sizes,
const int64_t* concat_sizes_range, const int64_t* axis_dimension_input_output_mapping,
T* output_data, const void** input_data, const HIP_LONG N) {
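// Variable-size variant: concat_sizes[i] is input i's size along the concat axis, concat_sizes_range holds the
// cumulative sums of concat_sizes (used below as range_left), and axis_dimension_input_output_mapping maps a
// position along the output's concat axis to the input that supplies it.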
HIP_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int input_index = axis_dimension_input_output_mapping[block_index];
int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
int block_offset = block_index - static_cast<int>(range_left);
HIP_LONG input_pos =
(outer_block_index * concat_sizes[input_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
id += kNumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
output_data[id] = value[i];
id += kNumThreadsPerBlock;
}
}
}
Status ConcatImpl(hipStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
const size_t output_size) {
HIP_LONG N = static_cast<HIP_LONG>(output_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
switch (element_bytes) {
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
hipLaunchKernelGGL(_ConcatKernel, blocksPerGrid, kNumThreadsPerBlock, 0, stream, \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_sizes, concat_sizes_range, \
axis_dimension_input_output_mapping, reinterpret_cast<ToHipType<type>::MappedType*>(output_data), input_data, \
N); \
} break;
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace rocm {
template <typename InputDataArray>
Status ConcatSameConcatDimImpl(hipStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
const InputDataArray input_data, const size_t output_size);
Status ConcatImpl(hipStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
const size_t output_size);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "expand.h"
#include "expand_impl.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
namespace rocm {
namespace {
// Logically expanded y could just be a view of x.
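// For example (illustrative trace): expanding x of shape [2,1,1,3] to y of shape [2,4,5,3] reduces to
// x_dims=[2,1,3], y_dims=[2,20,3]; the two adjacent broadcast axes (4 and 5) are merged into one axis of size 20.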
static void CalcEffectiveDims(TensorShapeVector& x_dims, TensorShapeVector& y_dims) {
TensorShapeVector x_reverse;
TensorShapeVector y_reverse;
int64_t xi = gsl::narrow_cast<int64_t>(x_dims.size()) - 1;
for (int64_t yi = gsl::narrow_cast<int64_t>(y_dims.size()) - 1; yi >= 0; --yi, --xi) {
int64_t xdim = (xi >= 0) ? x_dims[xi] : 1;
int64_t ydim = y_dims[yi];
if (xdim == ydim || xdim == 1) {
x_reverse.push_back(xdim);
y_reverse.push_back(ydim);
} else { // xdim < ydim && xdim > 1, split
ydim /= xdim;
x_reverse.push_back(xdim);
y_reverse.push_back(xdim);
x_reverse.push_back(1);
y_reverse.push_back(ydim);
}
}
x_dims.clear();
y_dims.clear();
x_dims.push_back(1);
y_dims.push_back(1);
// Compact the dims: drop (x=1, y=1) pairs, merge consecutive broadcast dims into (1, y1*y2*...), and merge
// consecutive non-broadcast dims.
for (int64_t i = gsl::narrow_cast<int64_t>(y_reverse.size()) - 1; i >= 0; --i) {
if (x_reverse[i] == 1) {
if (y_reverse[i] == 1) {
continue;
}
if (x_dims.back() == 1) {
y_dims.back() *= y_reverse[i];
} else {
x_dims.push_back(1);
y_dims.push_back(y_reverse[i]);
}
} else { // x_reverse[i] == y_reverse[i]
if (x_dims.back() == y_dims.back()) {
x_dims.back() *= x_reverse[i];
y_dims.back() *= y_reverse[i];
} else {
x_dims.push_back(x_reverse[i]);
y_dims.push_back(y_reverse[i]);
}
}
}
}
#ifdef ENABLE_TRAINING
TensorShapeVector ComputeOutputStrides(const TensorShape& input_shapes, const gsl::span<const int64_t>& input_strides,
const TensorShape& output_shapes) {
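// Broadcast dimensions get stride 0 so a strided view of the input can represent the expanded output.
// For example, input shape [1,3] with strides [3,1] expanded to output shape [2,3] yields output strides [0,1].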
const size_t rank = output_shapes.NumDimensions();
const size_t input_rank = input_shapes.NumDimensions();
if (input_rank == 0 || input_shapes.Size() == 1) {
return TensorShapeVector(rank, 0);
}
TensorShapeVector output_strides(rank);
const size_t offset = rank - input_rank;
for (size_t dim = rank - 1;; --dim) {
int64_t stride = 0;
int64_t input_dim_size = dim >= offset ? input_shapes[dim - offset] : 1;
if (input_dim_size == output_shapes[dim]) {
stride = dim >= offset ? input_strides[dim - offset] : output_shapes[dim + 1] * output_strides[dim + 1];
}
output_strides[dim] = stride;
if (dim == 0) break;
}
return output_strides;
}
#endif
} // namespace
Status Expand::ComputeInternal(OpKernelContext* ctx) const {
const auto& input_data_tensor = *ctx->Input<Tensor>(0);
const auto& input_shape_tensor = *ctx->Input<Tensor>(1);
// new shape to be expanded to
const auto* p_shape = input_shape_tensor.Data<int64_t>();
TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
TensorShape output_shape(output_dims);
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
auto& output_tensor = *ctx->Output(0, output_shape);
if (0 == output_shape.Size()) {
return Status::OK();
}
#ifdef ENABLE_TRAINING
// Strided output.
if (input_data_tensor.DataRaw() == output_tensor.DataRaw()) {
gsl::span<const int64_t> input_strides = input_data_tensor.Strides();
TensorShapeVector output_strides =
ComputeOutputStrides(input_data_tensor.Shape(), input_strides, output_shape);
output_tensor.SetShapeAndStrides(output_shape, output_strides);
return Status::OK();
}
#endif
output_dims = output_shape.AsShapeVector();
auto input_dims = input_data_tensor.Shape().AsShapeVector();
CalcEffectiveDims(input_dims, output_dims);
int rank = gsl::narrow_cast<int>(output_dims.size());
TensorPitches original_input_strides(input_dims);
TensorPitches original_output_strides(output_dims);
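// Dimensions being broadcast (input dim == 1) get stride 0 so the kernel re-reads the same input element.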
TArray<int64_t> input_strides(rank);
for (auto i = 0; i < rank; i++) {
input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i];
}
TArray<fast_divmod> output_strides(rank);
for (auto i = 0; i < rank; i++) {
output_strides[i] = fast_divmod(static_cast<int>(original_output_strides[i]));
}
return ExpandImpl(
Stream(),
input_data_tensor.DataType()->Size(),
gsl::narrow_cast<int>(output_shape.Size()),
gsl::narrow_cast<int>(input_data_tensor.Shape().Size()),
input_data_tensor.DataRaw(),
output_tensor.MutableDataRaw(),
output_strides,
input_strides);
}
#ifdef ENABLE_TRAINING
#define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0)
#else
#define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create())
#endif
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Expand, kOnnxDomain, 8, 12, kRocmExecutionProvider,
CREATE_EXPAND_KERNEL_DEF.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.InputMemoryType(OrtMemTypeCPUInput, 1),
Expand);
ONNX_OPERATOR_KERNEL_EX(Expand, kOnnxDomain, 13, kRocmExecutionProvider,
CREATE_EXPAND_KERNEL_DEF.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.InputMemoryType(OrtMemTypeCPUInput, 1),
Expand);
#undef CREATE_EXPAND_KERNEL_DEF
} // namespace rocm
}  // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
class Expand final : public RocmKernel {
public:
Expand(const OpKernelInfo& info) : RocmKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
Status ComputeOutputShape(
const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape);
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/cu_inc/common.cuh"
#include "expand_impl.h"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace rocm {
template <typename T, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _FillFromDataPtrKernel(T* output_data, const T* input_data, HIP_LONG N) {
HIP_LONG id = NumElementsPerThread * blockDim.x * blockIdx.x + threadIdx.x;
T val = *input_data;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = val;
id += NumThreadsPerBlock;
}
}
}
template <typename T>
void FillFromDataPtr(hipStream_t stream, T* output_data, const T* input_data, int64_t count) {
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
HIP_LONG N = static_cast<HIP_LONG>(count);
hipLaunchKernelGGL(HIP_KERNEL_NAME(_FillFromDataPtrKernel<T, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, output_data, input_data, N);
}
template <typename T>
__global__ void ExpandKernel2D(
const int N,
const T* input_data,
T* output_data,
const fast_divmod fdm_output_stride0,
const int input_view_stride0,
const int input_view_stride1) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
int dim0, dim1;
fdm_output_stride0.divmod(id, dim0, dim1);
output_data[id] = input_data[dim0 * input_view_stride0 + dim1 * input_view_stride1];
}
template <typename T, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void ExpandKernel(
const int rank,
const int N,
const T* input_data,
T* output_data,
const TArray<fast_divmod> output_strides,
const TArray<int64_t> input_strides) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
HIP_LONG index = 0;
HIP_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < output_strides.Capacity(); dim++) {
if (dim >= rank) {
break;
}
int q, r;
output_strides[dim].divmod(offset, q, r);
index += static_cast<int>(input_strides[dim]) * q;
offset = r;
}
value[i] = input_data[index];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = value[i];
id += NumThreadsPerBlock;
}
}
}
Status ExpandByFill(hipStream_t stream, const size_t element_size, const int N, const void* input_data, void* output_data) {
#define EXPAND_FILL_ON(TYPE) \
case sizeof(TYPE): \
FillFromDataPtr(stream, \
reinterpret_cast<TYPE*>(output_data), \
reinterpret_cast<const TYPE*>(input_data), \
static_cast<int64_t>(N)); \
break
switch (element_size) {
EXPAND_FILL_ON(int8_t);
EXPAND_FILL_ON(int16_t);
EXPAND_FILL_ON(int32_t);
EXPAND_FILL_ON(int64_t);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
}
return Status::OK();
}
Status Expand2D(
hipStream_t stream,
const size_t element_size,
const int N,
const void* input_data,
void* output_data,
const fast_divmod fdm_output_stride0,
const int input_view_stride0,
const int input_view_stride1) {
#define EXPAND2D_ON(TYPE) \
case sizeof(TYPE): \
hipLaunchKernelGGL(ExpandKernel2D, blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, \
N, reinterpret_cast<const TYPE*>(input_data), reinterpret_cast<TYPE*>(output_data), \
fdm_output_stride0, input_view_stride0, input_view_stride1); \
break
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
switch (element_size) {
EXPAND2D_ON(int8_t);
EXPAND2D_ON(int16_t);
EXPAND2D_ON(int32_t);
EXPAND2D_ON(int64_t);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
}
return Status::OK();
}
Status ExpandImpl(
hipStream_t stream,
const size_t element_size,
const int N_output,
const int N_input,
const void* input_data,
void* output_data,
const TArray<fast_divmod>& output_strides,
const TArray<int64_t>& input_strides) {
const int rank = static_cast<int>(output_strides.Size());
if (rank == 1) {
if (N_input == N_output) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(output_data, input_data, N_output * element_size, hipMemcpyDeviceToDevice, stream));
} else { // N_input == 1
return ExpandByFill(stream, element_size, N_output, input_data, output_data);
}
} else if (rank == 2) {
return Expand2D(stream, element_size, N_output, input_data, output_data,
output_strides[0],
static_cast<int>(input_strides[0]),
static_cast<int>(input_strides[1]));
}
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(N_output, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
#define EXPAND_ON(TYPE) \
case sizeof(TYPE): \
ExpandKernel<TYPE, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>( \
rank, N_output, reinterpret_cast<const TYPE*>(input_data), reinterpret_cast<TYPE*>(output_data), \
output_strides, input_strides); \
break
switch (element_size) {
EXPAND_ON(uint8_t);
EXPAND_ON(uint16_t);
EXPAND_ON(uint32_t);
EXPAND_ON(uint64_t);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace rocm {
Status ExpandImpl(
hipStream_t stream,
const size_t element_size,
const int N_output,
const int N_input,
const void* input_data,
void* output_data,
const TArray<fast_divmod>& output_strides,
const TArray<int64_t>& input_strides);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "eye_like.h"
#include "eye_like_impl.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/rocm/shared_inc/fast_divmod.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
EyeLike,
kOnnxDomain,
9,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<int32_t>()})
.TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<int32_t>()}),
EyeLike);
#define TYPED_FUNCTION_CALL(T) \
EyeLikeImpl<typename ToHipType<T>::MappedType>( \
Stream(), \
offset, \
dim1 + 1, \
reinterpret_cast<typename ToHipType<T>::MappedType*>(T2->MutableData<T>()), \
diag_count); \
break;
Status EyeLike::ComputeInternal(OpKernelContext* context) const {
const auto* T1 = context->Input<Tensor>(0);
ORT_ENFORCE(T1 != nullptr);
auto input_dims = T1->Shape().GetDims();
if (input_dims.size() != 2) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
}
// set output tensor shape same as input tensor and set all values to zero
auto* T2 = context->Output(0, input_dims);
HIP_RETURN_IF_ERROR(hipMemsetAsync(T2->MutableDataRaw(), 0, T2->SizeInBytes(), Stream()));
auto dim0 = input_dims[0];
auto dim1 = input_dims[1];
if ((k_ >= 0 && k_ >= dim1) || (k_ < 0 && (std::abs(k_)) >= dim0)) {
return Status::OK();
}
// Calculate the start offset and total number of elements in the diagonal.
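// For example, for a 3x4 input with k_ = 1: offset = 1, stripe = dim1 + 1 = 5, diag_count = min(4 - 1, 3) = 3,
// so ones are written at flat offsets 1, 6 and 11.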
size_t offset, diag_count;
if (k_ >= 0) {
offset = k_;
diag_count = std::min(dim1 - k_, dim0);
} else {
offset = (-k_) * dim1;
diag_count = std::min(dim0 + k_, dim1);
}
auto output_tensor_dtype = has_dtype_ ? static_cast<ONNX_NAMESPACE::TensorProto_DataType>(dtype_) : T1->GetElementType();
switch (output_tensor_dtype) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
TYPED_FUNCTION_CALL(float)
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
TYPED_FUNCTION_CALL(double)
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
TYPED_FUNCTION_CALL(int32_t)
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
TYPED_FUNCTION_CALL(uint64_t)
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
TYPED_FUNCTION_CALL(int64_t)
default:
ORT_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
class EyeLike final : public RocmKernel {
public:
EyeLike(const OpKernelInfo& info) : RocmKernel(info) {
if (!info.GetAttr("k", &k_).IsOK()) {
k_ = 0;
}
has_dtype_ = info.GetAttr("dtype", &dtype_).IsOK();
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool has_dtype_;
int64_t dtype_;
int64_t k_;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "eye_like_impl.h"
namespace onnxruntime {
namespace rocm {
template <typename T>
__global__ void _EyeLikeKernel(
size_t offset,
size_t stripe,
T* output_data,
HIP_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
// offset is the index of the first diagonal element, stripe is width + 1.
output_data[offset + id * stripe] = static_cast<T>(1);
}
template <typename T>
void EyeLikeImpl(
hipStream_t stream,
size_t offset,
size_t stripe,
T* output_data,
size_t diag_count) {
constexpr int block_size = 256;
int blocksPerGrid = (int)(ceil(static_cast<float>(diag_count) / block_size));
HIP_LONG N = static_cast<HIP_LONG>(diag_count);
hipLaunchKernelGGL(_EyeLikeKernel, blocksPerGrid, block_size, 0, stream, offset, stripe, output_data, N);
}
#define SPECIALIZED_IMPL(T) \
template void EyeLikeImpl<T>( \
hipStream_t stream, \
size_t offset, \
size_t stripe, \
T* output_data, \
size_t diag_count);
SPECIALIZED_IMPL(int32_t)
SPECIALIZED_IMPL(int64_t)
SPECIALIZED_IMPL(uint64_t)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/fast_divmod.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace rocm {
template <typename T>
void EyeLikeImpl(
hipStream_t stream,
size_t offset,     // offset of the first element in the diagonal
size_t stripe,     // stripe, here it's width + 1
T* output_data,    // output buffer
size_t diag_count  // total number of elements in the diagonal
);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "flatten.h"
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
1, 8,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
9, 10,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Flatten);
// explicitly support negative axis
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
11, 12,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
Flatten,
kOnnxDomain,
13,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Flatten);
Status Flatten::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
auto axis = axis_;
// The valid axis range is [-rank, rank] instead of [-rank, rank-1], so add an extra check to handle only the negative axis case.
if (axis < 0) {
axis = HandleNegativeAxis(axis, X_shape.NumDimensions()); // handle negative and enforce axis is valid
}
ORT_ENFORCE(gsl::narrow_cast<int64_t>(X_shape.NumDimensions()) >= axis, "The rank of input tensor must be >= axis");
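// Flatten reshapes the input into a 2-D tensor: [d_0 * ... * d_(axis-1), d_axis * ... * d_(n-1)].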
Tensor* Y = ctx->Output(0, {X_shape.SizeToDimension(axis), X_shape.SizeFromDimension(axis)});
// If the source and target pointers are not equal (non-in-place operation), we need to copy the data.
const void* source = X->DataRaw();
void* target = Y->MutableDataRaw();
if (target != source) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(target, source, X_shape.Size() * X->DataType()->Size(),
hipMemcpyDeviceToDevice, Stream()));
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
class Flatten final : public RocmKernel {
public:
Flatten(const OpKernelInfo& info) : RocmKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t axis_;
};
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tensor/gather_impl.h"
#include "core/providers/rocm/tensor/gather.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gather,
kOnnxDomain,
1, 10,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Gather);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gather,
kOnnxDomain,
11, 12,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Gather);
// explicit negative axis support
ONNX_OPERATOR_KERNEL_EX(
Gather,
kOnnxDomain,
13,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Gather);
Status Gather::ComputeInternal(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
const TensorShape& input_shape = p.input_tensor->Shape();
const int64_t block_size = input_shape.SizeFromDimension(p.axis + 1);
size_t N = p.indices_tensor->Shape().Size();
const int64_t input_block_size = input_shape.SizeFromDimension(p.axis);
const int64_t output_block_size = N * block_size;
const int64_t indices_max = input_shape[p.axis];
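// The input is viewed as [outer, indices_max, block_size] and the output as [outer, N, block_size]:
// block_size is the size of one slice below the gather axis, input_block_size = indices_max * block_size,
// and output_block_size = N * block_size.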
const void* input_data = p.input_tensor->DataRaw();
const void* indices_data = p.indices_tensor->DataRaw();
void* output_data = p.output_tensor->MutableDataRaw();
if (p.output_tensor->Shape().Size() == 0) {
return Status::OK();
}
const fast_divmod divmod_output_block_size(gsl::narrow_cast<int>(output_block_size));
const fast_divmod divmod_block_size(gsl::narrow_cast<int>(block_size));
const size_t element_size = p.input_tensor->DataType()->Size();
const size_t index_element_size = p.indices_tensor->DataType()->Size();
// The ROCM kernel implementation supports element sizes of
// int8_t, int16_t, int32_t and int64_t, which covers all supported
// types, since no computation is necessary, just data movement.
if (p.indices_tensor->IsDataType<int32_t>() ||
p.indices_tensor->IsDataType<int64_t>()) {
GatherImpl(
Stream(),
input_block_size,
indices_max,
divmod_output_block_size,
divmod_block_size,
indices_data,
index_element_size,
input_data,
element_size,
output_data,
p.output_tensor->Shape().Size());
return Status::OK();
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/cpu/tensor/gatherbase.h"
namespace onnxruntime {
namespace rocm {
class Gather final : public RocmKernel, public GatherBase {
public:
Gather(const OpKernelInfo& info) : RocmKernel(info), GatherBase(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tensor/gather_elements.h"
#include "core/providers/rocm/tensor/gather_elements_impl.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
namespace rocm {
// Ideally both input and indices could support strided tensors. In the training case, the indices tensor is an input
// to both GatherElements and GatherElementsGrad, so supporting strided indices is the more useful way to save memory.
// Therefore we only mark indices as MayStridedInput for now, and will do the same for input once needed.
#ifdef ENABLE_TRAINING
#define CREATE_GATHER_ELEMENTS_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedInput(1)
#else
#define CREATE_GATHER_ELEMENTS_KERNEL_DEF (*KernelDefBuilder::Create())
#endif
ONNX_OPERATOR_KERNEL_EX(GatherElements, kOnnxDomain, 13, kRocmExecutionProvider,
CREATE_GATHER_ELEMENTS_KERNEL_DEF.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
GatherElements);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(GatherElements, kOnnxDomain, 11, 12, kRocmExecutionProvider,
CREATE_GATHER_ELEMENTS_KERNEL_DEF
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind",
std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
GatherElements);
#undef CREATE_GATHER_ELEMENTS_KERNEL_DEF
void CoalesceDimensions(TensorShapeVector& input_shape, TensorShapeVector& indices_shape,
TensorShapeVector* p_indices_strides, int64_t axis, GatherScatterElementsArgs& args) {
size_t rank = input_shape.size();
if (axis < 0 || axis >= static_cast<int64_t>(rank)) ORT_THROW("Invalid axis in CoalesceDimensions: ", axis);
size_t new_axis = static_cast<size_t>(axis);
auto CanCoalesce = [&](size_t dst, size_t src) {
if (dst == static_cast<size_t>(new_axis) || src == static_cast<size_t>(new_axis)) return false;
if (input_shape[dst] == 1 && indices_shape[dst] == 1) return true;
if (input_shape[src] == 1 && indices_shape[src] == 1) return true;
return input_shape[dst] == indices_shape[dst] && input_shape[src] == indices_shape[src] &&
(!p_indices_strides || (*p_indices_strides)[dst] == indices_shape[src] * (*p_indices_strides)[src]);
};
size_t curr = 0;
for (size_t next = 1; next < rank; ++next) {
if (CanCoalesce(curr, next)) {
if (indices_shape[next] != 1 && p_indices_strides) {
(*p_indices_strides)[curr] = (*p_indices_strides)[next];
}
input_shape[curr] *= input_shape[next];
indices_shape[curr] *= indices_shape[next];
} else {
if (next == static_cast<size_t>(new_axis)) {
// Handle the case where all dims before the axis are size 1.
if (input_shape[curr] != 1 || indices_shape[curr] != 1) ++curr;
new_axis = static_cast<int64_t>(curr);
} else {
++curr;
}
if (curr != next) {
input_shape[curr] = input_shape[next];
indices_shape[curr] = indices_shape[next];
if (p_indices_strides) (*p_indices_strides)[curr] = (*p_indices_strides)[next];
}
}
}
// Handle the case where all dims after the axis are size 1.
if (curr > static_cast<size_t>(new_axis) && input_shape[curr] == 1 && indices_shape[curr] == 1) {
--curr;
}
size_t new_rank = curr + 1;
args.rank = static_cast<int64_t>(new_rank);
args.axis = static_cast<int64_t>(new_axis);
input_shape.resize(new_rank);
indices_shape.resize(new_rank);
if (p_indices_strides) {
p_indices_strides->resize(new_rank);
}
// Set the stride along the axis to 0 so we don't need an IF check in the kernel.
TensorPitches masked_input_strides_vec(input_shape);
args.input_stride_along_axis = masked_input_strides_vec[args.axis];
args.input_dim_along_axis = input_shape[args.axis];
masked_input_strides_vec[args.axis] = 0;
args.masked_input_strides = TArray<int64_t>(ToConstSpan(masked_input_strides_vec));
args.indices_fdms.SetSize(static_cast<int32_t>(new_rank));
TensorPitches indices_shape_strides(indices_shape);
for (int32_t i = 0; i < static_cast<int32_t>(new_rank); ++i) {
args.indices_fdms[i] = fast_divmod(gsl::narrow_cast<int>(indices_shape_strides[i]));
}
if (p_indices_strides) {
args.indices_strides = TArray<int64_t>(ToConstSpan(*p_indices_strides));
}
}
// GatherElementsGrad needs atomic_add, which supports float types only, so use half, float and double for 16-, 32-,
// and 64-bit data respectively.
ONNX_NAMESPACE::TensorProto_DataType GetElementType(size_t element_size) {
switch (element_size) {
case sizeof(int8_t):
return ONNX_NAMESPACE::TensorProto_DataType_INT8;
case sizeof(MLFloat16):
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
case sizeof(float):
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
case sizeof(double):
return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
// should not reach here as we validate that all relevant types are supported in the Compute method
default:
return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
}
}
#define CASE_GATHER_ELEMENTS_IMPL(type) \
case sizeof(type): { \
const type* indices_data = reinterpret_cast<const type*>(indices_data_raw); \
GatherElementsImpl(stream, input_data, indices_data, output_data, args); \
} break
template <typename T>
struct GatherElements::ComputeImpl {
Status operator()(hipStream_t stream, const void* input_data_raw, const void* indices_data_raw,
void* output_data_raw, const size_t index_element_size,
const GatherScatterElementsArgs& args) const {
typedef typename ToHipType<T>::MappedType HipT;
const HipT* input_data = reinterpret_cast<const HipT*>(input_data_raw);
HipT* output_data = reinterpret_cast<HipT*>(output_data_raw);
switch (index_element_size) {
CASE_GATHER_ELEMENTS_IMPL(int32_t);
CASE_GATHER_ELEMENTS_IMPL(int64_t);
// should not reach here as we validate that all relevant types are supported in the Compute method
default:
ORT_THROW("Unsupported indices element size by the GatherElements ROCM kernel");
}
return Status::OK();
}
};
#undef CASE_GATHER_ELEMENTS_IMPL
Status GatherElements::ComputeInternal(OpKernelContext* context) const {
// Process input data tensor
const auto* input_tensor = context->Input<Tensor>(0);
const auto& input_shape = input_tensor->Shape();
const int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());
// Process indices tensor
const auto* indices_tensor = context->Input<Tensor>(1);
const auto& indices_shape = indices_tensor->Shape();
const int64_t indices_size = indices_shape.Size();
// Handle negative axis if any
const int64_t axis = HandleNegativeAxis(axis_, input_rank);
// Validate input shapes and ranks (invoke the static method in the CPU GatherElements kernel that hosts the shared
// checks)
ORT_RETURN_IF_ERROR(onnxruntime::GatherElements::ValidateInputShapes(input_shape, indices_shape, axis));
// create output tensor
auto* output_tensor = context->Output(0, indices_shape);
// if there are no elements in 'indices' - nothing to process
if (indices_size == 0) return Status::OK();
GatherScatterElementsArgs args;
args.indices_size = indices_size;
TensorShapeVector input_shape_vec = input_shape.AsShapeVector();
TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector();
TensorShapeVector* p_indices_strides_vec = nullptr;
TensorShapeVector indices_strides_vec;
#ifdef ENABLE_TRAINING
if (!indices_tensor->IsContiguous()) {
indices_strides_vec = ToShapeVector(indices_tensor->Strides());
p_indices_strides_vec = &indices_strides_vec;
}
#endif
CoalesceDimensions(input_shape_vec, indices_shape_vec, p_indices_strides_vec, axis, args);
// Use element size instead of concrete types so we can specialize fewer template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
ORT_THROW("Unsupported element size by the GatherElements ROCM kernel");
}
utils::MLTypeCallDispatcher<int8_t, MLFloat16, float, double> t_disp(dtype);
return t_disp.InvokeRet<Status, ComputeImpl>(Stream(), input_tensor->DataRaw(), indices_tensor->DataRaw(),
output_tensor->MutableDataRaw(), indices_tensor->DataType()->Size(),
args);
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/shared_library/provider_api.h"
namespace onnxruntime {
namespace rocm {
struct GatherScatterElementsArgs;
// Coalesce contiguous axes that have the same dim values for both input and indices (excluding the gather/scatter
// axis) so that we have fewer divmods to compute in the kernels.
// For example:
// shape(input)=[2,2,2], shape(indices)=[2,2,3], axis=2 is the same as shape(input)=[4,2], shape(indices)=[4,3], axis=1
// shape(input)=[2,1,2,2,3,2,2], shape(indices)=[2,1,2,2,2,2,2], axis=3 is the same as
// shape(input)=[4,2,3,4], shape(indices)=[4,2,2,4], axis=1
// If indices is strided, dim i (outer) and dim j are contiguous when strides[i] = shape[j] * strides[j].
// For example:
// shape(indices)=[2,3,4,5], strides(indices)=[0,20,5,1], then dim-2 and dim-3 are contiguous (5==5*1),
// dim-1 and dim-2 are contiguous (20==4*5), but dim-0 and dim-1 are not contiguous (0!=3*20).
void CoalesceDimensions(TensorShapeVector& input_shape, TensorShapeVector& indices_shape,
TensorShapeVector* p_indices_strides, int64_t axis, GatherScatterElementsArgs& args);
ONNX_NAMESPACE::TensorProto_DataType GetElementType(size_t element_size);
class GatherElements final : public RocmKernel {
public:
GatherElements(const OpKernelInfo& info) : RocmKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value");
}
~GatherElements() = default;
Status ComputeInternal(OpKernelContext* context) const override;
private:
template <typename T>
struct ComputeImpl;
int64_t axis_;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tensor/gather_elements_impl.h"
#include "core/providers/rocm/tensor/scatter_elements_impl.h"
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/rocm/tensor/gather_elements_grad_impl.h"
#endif
#include "core/providers/rocm/atomic/common.cuh"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {
namespace {
constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;
constexpr int kThreadWorkSize = 4;
// General case to compute the input (for Gather) / output (for Scatter) and indices data offsets given the thread ID
// using strides and fast_divmods. The offsets are returned in a 2-element TArray.
template <bool IsStridedIndices>
struct OffsetCalculator {
OffsetCalculator(const int rank, const TArray<int64_t> masked_input_strides, const TArray<fast_divmod> indices_fdms,
const TArray<int64_t> indices_strides)
: rank_(rank), indices_fdms_(indices_fdms) {
masked_input_strides_.SetSize(rank);
if (IsStridedIndices) indices_strides_.SetSize(rank);
for (int dim = 0; dim < rank; ++dim) {
masked_input_strides_[dim] = static_cast<HIP_LONG>(masked_input_strides[dim]);
if (IsStridedIndices) indices_strides_[dim] = static_cast<HIP_LONG>(indices_strides[dim]);
}
}
__device__ __forceinline__ TArray<HIP_LONG, 2> get(HIP_LONG linear_idx) const {
TArray<HIP_LONG, 2> offsets;
offsets[0] = 0;
offsets[1] = IsStridedIndices ? 0 : linear_idx;
HIP_LONG q, r = linear_idx;
#pragma unroll
for (int dim = 0; dim < indices_fdms_.Capacity(); ++dim) {
if (dim == rank_) break;
indices_fdms_[dim].divmod(r, q, r);
offsets[0] += masked_input_strides_[dim] * q;
if (IsStridedIndices) offsets[1] += indices_strides_[dim] * q;
}
return offsets;
}
int rank_;
TArray<fast_divmod> indices_fdms_;
TArray<HIP_LONG> masked_input_strides_;
TArray<HIP_LONG> indices_strides_;
};
// Optimization for the 2D case to compute the input (for Gather) / output (for Scatter) and indices data offsets
// given the thread ID, so we don't need a FOR loop of fast_divmod computes.
// The offsets are returned in a 2-element TArray.
template <bool IsOuterAxis, bool IsStridedIndices>
struct OffsetCalculatorFor2D {
OffsetCalculatorFor2D(const fast_divmod indices_row_size_fdm, const int64_t input_row_size,
const TArray<int64_t> indices_strides)
: indices_row_size_fdm_(indices_row_size_fdm), input_row_size_(static_cast<HIP_LONG>(input_row_size)) {
if (IsStridedIndices) {
indices_strides_.SetSize(2);
indices_strides_[0] = static_cast<HIP_LONG>(indices_strides[0]);
indices_strides_[1] = static_cast<HIP_LONG>(indices_strides[1]);
}
}
__device__ __forceinline__ TArray<HIP_LONG, 2> get(HIP_LONG linear_idx) const {
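// linear_idx is split into an indices row q and column r. When the gather/scatter axis is the outer axis, only the
// column contributes to the input/output offset (the axis part comes from the index value in the kernel); otherwise
// the offset is the start of input row q.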
TArray<HIP_LONG, 2> offsets;
if (IsStridedIndices) {
HIP_LONG q, r = linear_idx;
indices_row_size_fdm_.divmod(r, q, r);
offsets[0] = IsOuterAxis ? r : q * input_row_size_;
offsets[1] = q * indices_strides_[0] + r * indices_strides_[1];
} else {
offsets[0] =
IsOuterAxis ? indices_row_size_fdm_.mod(linear_idx) : indices_row_size_fdm_.div(linear_idx) * input_row_size_;
offsets[1] = linear_idx;
}
return offsets;
}
fast_divmod indices_row_size_fdm_;
HIP_LONG input_row_size_;
TArray<HIP_LONG> indices_strides_;
};
} // namespace
template <class T>
struct FuncAssignment {
__device__ __inline__ void operator()(T* a, const T* b) const { *a = *b; }
};
template <typename T, typename TIndex, bool IsGather, typename OffsetCalcT, typename TFunc>
__global__ void _GatherScatterElementsKernel(const T* src_data, const TIndex* indices_data, T* output_data,
const int64_t input_dim_along_axis, const int64_t input_stride_along_axis,
const OffsetCalcT offset_calc, const TFunc& func, HIP_LONG N) {
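// One kernel serves both directions: for Gather (IsGather == true) src_data is the input tensor, values are fetched
// at the computed input offsets and then written to output_data contiguously; for Scatter, src_data is the updates
// tensor, read contiguously and written via func at the computed output offsets. Negative indices are wrapped by
// input_dim_along_axis.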
HIP_LONG start = kThreadsPerBlock * kThreadWorkSize * blockIdx.x + threadIdx.x;
HIP_LONG id;
T value[kThreadWorkSize];
if (!IsGather) {
id = start;
#pragma unroll
for (int work = 0; work < kThreadWorkSize; ++work) {
if (id < N) {
value[work] = src_data[id];
id += kThreadsPerBlock;
}
}
}
id = start;
#pragma unroll
for (int work = 0; work < kThreadWorkSize; ++work) {
if (id < N) {
TArray<HIP_LONG, 2> offsets = offset_calc.get(id);
int64_t input_offset_along_axis = static_cast<int64_t>(indices_data[offsets[1]]);
if (input_offset_along_axis >= -input_dim_along_axis && input_offset_along_axis < input_dim_along_axis) {
if (input_offset_along_axis < 0) input_offset_along_axis += input_dim_along_axis;
HIP_LONG input_offset = offsets[0] + static_cast<HIP_LONG>(input_offset_along_axis * input_stride_along_axis);
if (IsGather) {
func(value + work, src_data + input_offset);
} else {
func(output_data + input_offset, value + work);
}
}
id += kThreadsPerBlock;
}
}
if (IsGather) {
id = start;
#pragma unroll
for (int work = 0; work < kThreadWorkSize; ++work) {
if (id < N) {
output_data[id] = value[work];
id += kThreadsPerBlock;
}
}
}
}
#define LAUNCH_GATHER_SCATTER_ELEMENTS_2D_KERNEL(src_data, is_outer_axis, is_strided_indices, is_gather) \
auto offset_calc = OffsetCalculatorFor2D<is_outer_axis, is_strided_indices>(args.indices_fdms[0], input_row_size, \
args.indices_strides); \
_GatherScatterElementsKernel<T, TIndex, is_gather, decltype(offset_calc), decltype(func)> \
<<<blocksPerGrid, kThreadsPerBlock, 0, stream>>>(src_data, indices_data, output_data, args.input_dim_along_axis, \
args.input_stride_along_axis, offset_calc, func, N)
#define LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL(src_data, is_strided_indices, is_gather) \
auto offset_calc = \
OffsetCalculator<is_strided_indices>(rank, args.masked_input_strides, args.indices_fdms, args.indices_strides); \
_GatherScatterElementsKernel<T, TIndex, is_gather, decltype(offset_calc), decltype(func)> \
<<<blocksPerGrid, kThreadsPerBlock, 0, stream>>>(src_data, indices_data, output_data, args.input_dim_along_axis, \
args.input_stride_along_axis, offset_calc, func, N)
#define HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES(src_data, is_outer_axis, is_gather) \
if (args.indices_strides.Size() > 0) { \
LAUNCH_GATHER_SCATTER_ELEMENTS_2D_KERNEL(src_data, is_outer_axis, true, is_gather); \
} else { \
LAUNCH_GATHER_SCATTER_ELEMENTS_2D_KERNEL(src_data, is_outer_axis, false, is_gather); \
}
template <typename T, typename TIndex>
void GatherElementsImpl(hipStream_t stream, const T* input_data, const TIndex* indices_data, T* output_data,
const GatherScatterElementsArgs& args) {
HIP_LONG N = static_cast<HIP_LONG>(args.indices_size);
int blocksPerGrid = static_cast<int>(CeilDiv(N, kThreadsPerBlock * kThreadWorkSize));
auto func = FuncAssignment<T>();
if (args.rank == 2) {
int64_t input_row_size = args.masked_input_strides[0];
if (args.axis == 0) {
HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES(input_data, true, true);
} else {
HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES(input_data, false, true);
}
return;
}
int rank = static_cast<int>(args.rank);
if (args.indices_strides.Size() > 0) {
LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL(input_data, true, true);
} else {
// Save one divmod in kernel if axis is the last dim.
if (args.rank == args.axis + 1) rank -= 1;
LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL(input_data, false, true);
}
}
template <typename T, typename TIndex, typename TFunc>
Status ScatterElementsImplInternal(hipStream_t stream, const T* input_data, const TIndex* indices_data,
const T* updates_data, T* output_data, const GatherScatterElementsArgs& args,
const TFunc& func) {
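// Unless the scatter is in place, start from a copy of the input, then let the kernel overwrite the indexed
// positions with the updates via func (plain assignment for ScatterElements, atomic add for GatherElementsGrad).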
if (input_data != output_data) {
HIP_RETURN_IF_ERROR(
hipMemcpyAsync(output_data, input_data, args.input_size * sizeof(T), hipMemcpyDeviceToDevice, stream));
}
if (args.indices_size == 0) return Status::OK();
HIP_LONG N = static_cast<HIP_LONG>(args.indices_size);
int blocksPerGrid = static_cast<int>(CeilDiv(N, kThreadsPerBlock * kThreadWorkSize));
if (args.rank == 2) {
int64_t input_row_size = args.masked_input_strides[0];
if (args.axis == 0) {
HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES(updates_data, true, false);
} else {
HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES(updates_data, false, false);
}
return Status::OK();
}
int rank = static_cast<int>(args.rank);
if (args.indices_strides.Size() > 0) {
LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL(updates_data, true, false);
} else {
// Save one divmod in kernel if axis is the last dim.
if (args.rank == args.axis + 1) rank -= 1;
LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL(updates_data, false, false);
}
return Status::OK();
}
#undef HANDLE_GATHER_SCATTER_ELEMENTS_2D_IS_STRIDED_INDICES
#undef LAUNCH_GATHER_SCATTER_ELEMENTS_KERNEL
#undef LAUNCH_GATHER_SCATTER_ELEMENTS_2D_KERNEL
template <typename T, typename TIndex>
Status ScatterElementsImpl(hipStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data,
T* output_data, const GatherScatterElementsArgs& args) {
return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
FuncAssignment<T>());
}
#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
template void GatherElementsImpl<T, TIndex>(hipStream_t stream, const T* input_data, const TIndex* indices_data, \
T* output_data, const GatherScatterElementsArgs& args); \
template Status ScatterElementsImpl<T, TIndex>(hipStream_t stream, const T* input_data, const TIndex* indices_data, \
const T* updates_data, T* output_data, \
const GatherScatterElementsArgs& args);
#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(T) \
GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int32_t) \
GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, int64_t)
// GatherElementsGrad needs atomic_add, which supports float types only, so use half, float and double for 16-, 32-,
// and 64-bit data respectively.
GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(int8_t)
GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(half)
GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(float)
GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(double)
#undef GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL
#undef GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL
#ifdef ENABLE_TRAINING
template <class T>
struct FuncAtomicAdd {
__device__ __inline__ void operator()(T* a, const T* b) const { atomic_add(a, *b); }
};
template <typename T, typename TIndex>
Status GatherElementsGradImpl(hipStream_t stream, const TIndex* indices_data, const T* updates_data, T* output_data,
const GatherScatterElementsArgs& args) {
// Pass output_data as the input_data parameter intentionally, to skip the input_data copy,
// which is not applicable for GatherElementsGrad.
return ScatterElementsImplInternal(stream, output_data, indices_data, updates_data, output_data, args,
FuncAtomicAdd<T>());
}
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
template Status GatherElementsGradImpl<T, TIndex>(hipStream_t stream, const TIndex* indices_data, \
const T* updates_data, T* output_data, \
const GatherScatterElementsArgs& args);
#define GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(T) \
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int32_t) \
GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL(T, int64_t)
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(half)
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(float)
GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL(double)
#undef GATHER_ELEMENTS_GRAD_SPECIALIZED_SCATTER_ADD_IMPL
#undef GATHER_ELEMENTS_GRAD_SPECIALIZED_TINDEX_IMPL
#endif
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
struct GatherScatterElementsArgs {
int64_t rank;
int64_t axis;
int64_t input_size;
int64_t input_dim_along_axis;
int64_t input_stride_along_axis;
TArray<int64_t> masked_input_strides;
TArray<fast_divmod> indices_fdms;
TArray<int64_t> indices_strides;
int64_t indices_size;
};
template <typename T, typename TIndex>
void GatherElementsImpl(hipStream_t stream, const T* input_data, const TIndex* indices_data, T* output_data,
const GatherScatterElementsArgs& args);
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/cu_inc/common.cuh"
#include "gather_impl.h"
namespace onnxruntime {
namespace rocm {
__host__ __device__ inline int64_t GetIndexValue(const void* index_data, size_t index_element_size, size_t offset) {
switch (index_element_size) {
case sizeof(int32_t):
return *(reinterpret_cast<const int32_t*>(index_data) + offset);
break;
case sizeof(int64_t):
return *(reinterpret_cast<const int64_t*>(index_data) + offset);
break;
default:
break;
}
// What is a sensible thing to do here?
assert(false);
return std::numeric_limits<int64_t>::max();
}
template <typename T>
__global__ void _GatherKernel(
const int64_t input_block_size,
const int64_t indices_max,
const fast_divmod output_block_size,
const fast_divmod block_size,
const void* indices_data,
const size_t index_element_size,
const T* input_data,
T* output_data,
const HIP_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
HIP_LONG input_index = 0;
int input_block_index, block_offset;
output_block_size.divmod(id, input_block_index, block_offset);
int indices_index, offset;
block_size.divmod(block_offset, indices_index, offset);
int64_t idx = GetIndexValue(indices_data, index_element_size, indices_index);
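// Negative indices are interpreted relative to the axis size; anything still out of range yields 0 in the output.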
idx = idx < 0 ? idx + indices_max : idx;
if (idx < 0 || idx >= indices_max) {
output_data[id] = 0;
return;
}
input_index = input_block_index * input_block_size + idx * block_size.d_ + offset;
output_data[id] = input_data[input_index];
}
void GatherImpl(
hipStream_t stream,
const int64_t input_block_size,
const int64_t indices_max,
const fast_divmod& output_block_size,
const fast_divmod& block_size,
const void* indices_data,
size_t index_element_size,
const void* input_data,
size_t element_size,
void* output_data,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
switch (element_size) {
case sizeof(int8_t): {
using HipType = typename ToHipType<int8_t>::MappedType;
hipLaunchKernelGGL(_GatherKernel, blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const HipType*>(input_data), reinterpret_cast<HipType*>(output_data), (HIP_LONG)N);
} break;
case sizeof(int16_t): {
using HipType = typename ToHipType<int16_t>::MappedType;
hipLaunchKernelGGL(_GatherKernel, blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const HipType*>(input_data), reinterpret_cast<HipType*>(output_data), (HIP_LONG)N);
} break;
case sizeof(int32_t): {
using HipType = typename ToHipType<int32_t>::MappedType;
hipLaunchKernelGGL(_GatherKernel, blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const HipType*>(input_data), reinterpret_cast<HipType*>(output_data), (HIP_LONG)N);
} break;
case sizeof(int64_t): {
using HipType = typename ToHipType<int64_t>::MappedType;
hipLaunchKernelGGL(_GatherKernel, blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,
reinterpret_cast<const HipType*>(input_data), reinterpret_cast<HipType*>(output_data), (HIP_LONG)N);
} break;
default:
ORT_THROW("Unsupported element size by the Gather ROCM kernel");
}
}
} // namespace rocm
} // namespace onnxruntime