Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include "core/common/common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Launch the softmax kernel that does not use compact memory.
Status LaunchLongformerSoftmaxSimpleKernel(
hipStream_t stream,
rocblas_handle rocblas,
void* workspace, // softmax space
const void* q, // transposed Q with shape (B, N, S, H)
const void* k, // transposed K with shape (B, N, S, H)
const void* v, // transposed V with shape (B, N, S, H)
const void* attention_mask, // attention mask with shape (B, S); value 0.0 means not masked, -10000.0 means masked.
const void* global_q, // Q for global tokens with shape (B, N, S, H)
const void* global_k, // K for global tokens with shape (B, N, S, H)
const void* global_v, // V for global tokens with shape (B, N, S, H)
const int* global_attention, // global attention flags with shape (B, S), with value 0 for local and 1 for global.
const int* global_index, // Global index with shape (B, S)
const int* batch_global_num, // Number of global tokens per batch with shape (B, 1)
void* pinned_buffer, // pinned CPU memory holding the number of global tokens per batch, with shape (B, 1)
void* output, // output with shape (B, N, S, H)
float scaler, // scaling factor for the attention scores
int batch_size, // batch size
int sequence_length, // sequence length
int num_heads, // number of heads
int head_size, // hidden size per head
int attention_window, // one-sided window size
size_t element_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <hipcub/hipcub.hpp>
#include <hipcub/device/device_partition.hpp>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "longformer_global_impl.h"
using namespace onnxruntime::rocm;
using namespace hipcub;
namespace onnxruntime {
namespace contrib {
namespace rocm {
size_t GetGlobalScratchSize(int sequence_length) {
// Global Index scratch layout:
// [sequence_index: int S][tmp_storage: int 1024x1]
return sizeof(int) * (sequence_length + 1024);
}
__global__ void InitSequenceIndexKernel(int* sequence_index, int sequence_length) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < sequence_length; i += blockDim.x * gridDim.x) { // grid-stride loop
sequence_index[i] = i;
}
}
Status BuildGlobalIndex(
const hipDeviceProp_t& device_prop,
hipStream_t stream,
const int* global_attention,
int batch_size,
int sequence_length,
int* global_index,
int* batch_global_num,
void* scratch,
size_t scratch_size) {
int* sequence_index = (int*)scratch;
int* tmp_storage = sequence_index + sequence_length;
const int threads = device_prop.maxThreadsPerBlock;
int blocks = CeilDiv(sequence_length, threads);
hipLaunchKernelGGL(InitSequenceIndexKernel, blocks, threads, 0, stream, sequence_index, sequence_length);
// Determine temporary device storage size.
// For int* inputs/outputs, it needs 767 bytes. We reserve 1024 * 4 bytes, which should be enough.
size_t temp_storage_bytes = 0;
HIP_RETURN_IF_ERROR(hipcub::DevicePartition::Flagged(
NULL, temp_storage_bytes, sequence_index,
global_attention, global_index, batch_global_num, sequence_length, stream));
if (temp_storage_bytes + sizeof(int) * sequence_length > scratch_size) {
ORT_THROW("LongformerAttention scratch space is not large enough. Temp storage bytes are", temp_storage_bytes);
}
// Find the global attention indices and number of global attention tokens
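// hipcub::DevicePartition::Flagged writes the sequence indices whose global_attention flag is nonzero
// to the front of global_index for this batch entry (the remaining indices follow), and stores the
// number of flagged positions in batch_global_num[i].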
for (int i = 0; i < batch_size; ++i) {
HIP_RETURN_IF_ERROR(hipcub::DevicePartition::Flagged(
reinterpret_cast<void*>(tmp_storage), temp_storage_bytes, sequence_index,
global_attention + i * sequence_length, global_index + i * sequence_length,
batch_global_num + i, sequence_length, stream));
}
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "hip/hip_runtime.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Size of global Index scratch in bytes.
size_t GetGlobalScratchSize(int sequence_length);
// Find the global attention indices and number of global attention tokens
Status BuildGlobalIndex(
const hipDeviceProp_t& device_prop,
hipStream_t stream,
const int* global_attention,
int batch_size,
int sequence_length,
int* global_index,
int* batch_global_num,
void* scratch,
size_t scratch_size);
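// Typical call sequence (illustrative sketch; allocation and variable names are the caller's):
//   size_t scratch_bytes = GetGlobalScratchSize(sequence_length);
//   void* scratch = ...;  // device buffer of scratch_bytes
//   ORT_RETURN_IF_ERROR(BuildGlobalIndex(device_prop, stream, global_attention, batch_size,
//                                        sequence_length, global_index, batch_global_num,
//                                        scratch, scratch_bytes));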
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "ngram_repeat_block.h"
#include "ngram_repeat_block_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
NGramRepeatBlock,
kMSDomain,
1,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("Tid", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
NGramRepeatBlock);
using namespace ONNX_NAMESPACE;
NGramRepeatBlock::NGramRepeatBlock(const OpKernelInfo& info) : RocmKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("ngram_size", &ngram_size_).IsOK());
ORT_ENFORCE(ngram_size_ > 0);
}
Status NGramRepeatBlock::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_ids = context->Input<Tensor>(0);
const Tensor* scores = context->Input<Tensor>(1);
Tensor* output = context->Output(0, scores->Shape());
const auto* scores_source = static_cast<const float*>(scores->DataRaw());
auto* scores_target = static_cast<float*>(output->MutableDataRaw());
if (scores_source != scores_target) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(scores_target, scores_source, scores->Shape().Size() * sizeof(float), hipMemcpyDeviceToDevice, Stream()));
}
const auto& input_ids_dims = input_ids->Shape().GetDims();
const auto& scores_dims = scores->Shape().GetDims();
ORT_ENFORCE(input_ids_dims.size() == 2);
ORT_ENFORCE(scores_dims.size() == 2);
int64_t batch_size = input_ids_dims[0];
int64_t cur_len = input_ids_dims[1];
ORT_ENFORCE(scores_dims[0] == batch_size);
int64_t vocab_size = scores_dims[1];
if (cur_len + 1 < ngram_size_) {
return Status::OK();
}
const auto* input_ids_data = static_cast<const int64_t*>(input_ids->DataRaw(input_ids->DataType()));
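// Arguments below: step = cur_len - 1 (index of the last generated token), max_predict_len = cur_len
// (row stride of input_ids), and beam_size = 1 since each row of input_ids is treated as an
// independent sequence.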
NGramRepeatBlockImpl(
Stream(),
input_ids_data,
scores_target,
gsl::narrow_cast<int>(batch_size),
gsl::narrow_cast<int>(cur_len - 1),
gsl::narrow_cast<int>(cur_len),
gsl::narrow_cast<int>(vocab_size),
gsl::narrow_cast<int>(1),
gsl::narrow_cast<int>(ngram_size_));
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
class NGramRepeatBlock final : public RocmKernel {
public:
NGramRepeatBlock(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
int64_t ngram_size_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for blocking repeated n-grams.
*/
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/ngram_repeat_block_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(const int64_t* __restrict__ tokens,
float* __restrict__ lprobs,
int max_predict_len, int vocab_size,
int no_repeat_ngram_size) {
auto row = blockIdx.x;
auto col = threadIdx.x;
auto start = row * (max_predict_len) + col;
// Each thread compares the n-gram starting at its thread index with the final
// n-gram starting at position step - no_repeat_ngram_size + 2.
auto check_start_pos = blockDim.x;
auto lprob_start = row * vocab_size;
bool is_banned = true;
extern __shared__ int64_t tokens_shm[];
tokens_shm[col] = tokens[start];
if (col == blockDim.x - 1) {
for (int i = 1; i < no_repeat_ngram_size; i++) {
if (col + i < max_predict_len) {
tokens_shm[col + i] = tokens[start + i];
}
}
}
__syncthreads();
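// tokens_shm now holds this row's tokens from position 0 through 'step';
// check_start_pos is the index where the final (no_repeat_ngram_size - 1)-token prefix begins.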
for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
is_banned = false;
}
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}
}
// Allocate blocks and threads based on batch size and sequence length, then launch the kernel.
void NGramRepeatBlockImpl(
hipStream_t stream,
const int64_t* tokens_ptr,
float* scores_ptr,
int bsz,
int step,
int max_predict_len,
int vocab_size,
int beam_size,
int no_repeat_ngram_size) {
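// One thread per candidate n-gram start position: positions 0 .. step - no_repeat_ngram_size + 1
// each begin an (n-1)-token prefix that is compared against the latest prefix.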
int threads = step - no_repeat_ngram_size + 2;
if (threads <= 0) return;
int blocks = bsz * beam_size;
int shared_mem_size = (step + 1) * sizeof(int64_t);
// Launching N blocks where N is the number of samples in a batch (beams * bsz).
// Launching T threads where T is the number of previous n-grams in a sample.
// Allocating shared mem per block for faster access of input tokens, since
// each token will be accessed N times to compare with the current n-gram, where
// N is the n-gram size.
hipLaunchKernelGGL(banRepeatedTokens, blocks, threads, shared_mem_size, stream,
tokens_ptr, scores_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
void NGramRepeatBlockImpl(
hipStream_t stream,
const int64_t* tokens_ptr,
float* scores_ptr,
int bsz,
int step,
int max_predict_len,
int vocab_size,
int beam_size,
int no_repeat_ngram_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/remove_padding.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RemovePadding, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.OutputMemoryType(OrtMemTypeCPUOutput, 3) /*max_token_count on CPU*/ \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
RemovePadding<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
template <typename T>
RemovePadding<T>::RemovePadding(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
}
template <typename T>
Status RemovePadding<T>::ComputeInternal(OpKernelContext* context) const {
// shape of inputs:
// input: (batch_size, sequence_length, hidden_size)
// sequence_token_count: (batch_size)
// shape of outputs:
// output: (total_tokens, hidden_size)
// token_offset: (batch_size, sequence_length)
// cumulated_seq_len: (batch_size + 1)
// max_token_count: (1)
const Tensor* input = context->Input<Tensor>(0);
const Tensor* sequence_token_count = context->Input<Tensor>(1);
const auto& dims = input->Shape().GetDims();
if (dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
dims.size());
}
int64_t batch_size = dims[0];
int64_t sequence_length = dims[1];
int64_t hidden_size = dims[2];
auto token_count_buffer = GetScratchBuffer<int>(2);
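// token_count_buffer receives two device-side ints from LaunchGetTokenOffset:
// [0] total token count across the batch, [1] max token count of any sequence.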
TensorShapeVector token_offset_shape(2);
token_offset_shape[0] = batch_size;
token_offset_shape[1] = sequence_length;
Tensor* token_offset = context->Output(1, token_offset_shape);
TensorShapeVector cumulated_seq_len_shape(1);
cumulated_seq_len_shape[0] = batch_size + static_cast<int64_t>(1);
Tensor* cumulated_seq_len = context->Output(2, cumulated_seq_len_shape);
LaunchGetTokenOffset(token_count_buffer.get(),
token_offset->MutableData<int>(),
cumulated_seq_len->MutableData<int>(),
sequence_token_count->Data<int>(),
static_cast<int>(batch_size),
static_cast<int>(sequence_length),
Stream());
HIP_RETURN_IF_ERROR(hipGetLastError());
// Copy token_count to CPU
auto pinned_buffer = AllocateBufferOnCPUPinned<int>(2);
int* token_count_pinned = pinned_buffer.get();
HIP_RETURN_IF_ERROR(hipMemcpyAsync(token_count_pinned,
token_count_buffer.get(),
sizeof(int) * 2,
hipMemcpyDeviceToHost,
Stream()));
// Wait until token_count is copied to host.
HIP_RETURN_IF_ERROR(hipStreamSynchronize(Stream()));
int total_token_count = token_count_pinned[0];
int max_token_count = token_count_pinned[1];
TensorShapeVector output_shape(2);
output_shape[0] = static_cast<int64_t>(total_token_count);
output_shape[1] = hidden_size;
Tensor* output = context->Output(0, output_shape);
TensorShapeVector max_token_count_shape(1);
max_token_count_shape[0] = 1;
Tensor* max_token_count_tensor = context->Output(3, max_token_count_shape);
max_token_count_tensor->MutableData<int>()[0] = max_token_count;
typedef typename ToHipType<T>::MappedType HipT;
LaunchRemovePadding<HipT>(
reinterpret_cast<HipT*>(output->MutableData<T>()),
reinterpret_cast<const HipT*>(input->Data<T>()),
token_offset->Data<int>(),
total_token_count,
static_cast<int>(hidden_size),
Stream());
HIP_RETURN_IF_ERROR(hipGetLastError());
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class RemovePadding final : public RocmKernel {
public:
RemovePadding(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/restore_padding.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RestorePadding, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
RestorePadding<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
template <typename T>
RestorePadding<T>::RestorePadding(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
}
template <typename T>
Status RestorePadding<T>::ComputeInternal(OpKernelContext* context) const {
// shape of inputs:
// input: (total_tokens, hidden_size)
// token_offset: (batch_size, sequence_length)
// shape of outputs:
// output: (batch_size, sequence_length, hidden_size)
const Tensor* input = context->Input<Tensor>(0);
const Tensor* token_offset = context->Input<Tensor>(1);
const auto& dims = input->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 2 dimensions, got ",
dims.size());
}
int64_t total_tokens = dims[0];
int64_t hidden_size = dims[1];
const auto& token_offset_dims = token_offset->Shape().GetDims();
if (token_offset_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'token_offset' is expected to have 2 dimensions, got ",
token_offset_dims.size());
}
int64_t batch_size = token_offset_dims[0];
int64_t sequence_length = token_offset_dims[1];
TensorShapeVector output_shape(3);
output_shape[0] = batch_size;
output_shape[1] = sequence_length;
output_shape[2] = hidden_size;
Tensor* output = context->Output(0, output_shape);
typedef typename ToHipType<T>::MappedType HipT;
LaunchRestorePadding<HipT>(
reinterpret_cast<HipT*>(output->MutableData<T>()),
reinterpret_cast<const HipT*>(input->Data<T>()),
token_offset->Data<int>(),
static_cast<int>(total_tokens),
static_cast<int>(hidden_size),
static_cast<int>(batch_size),
static_cast<int>(sequence_length),
Stream());
HIP_RETURN_IF_ERROR(hipGetLastError());
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class RestorePadding final : public RocmKernel {
public:
RestorePadding(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// A wrapper class for hipEvent_t that destroys the event automatically, to avoid leaking the event.
class AutoDestoryCudaEvent {
public:
AutoDestoryCudaEvent() : rocm_event_(nullptr) {
}
~AutoDestoryCudaEvent() {
if (rocm_event_ != nullptr)
(void)hipEventDestroy(rocm_event_);
}
hipEvent_t& Get() {
return rocm_event_;
}
private:
hipEvent_t rocm_event_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "grid_sample.h"
#include "grid_sample_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
GridSample<T>);
REGISTER_KERNEL_TYPED(float)
template <typename T>
GridSample<T>::GridSample(const OpKernelInfo& info) : RocmKernel(info) {
std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic",
"mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
"padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
if (mode_str == "bicubic") {
mode_i_ = 2;
} else if (mode_str == "nearest") {
mode_i_ = 1;
} else {
mode_i_ = 0;
}
if (padding_mode_str == "reflection") {
padding_mode_i_ = 2;
} else if (padding_mode_str == "border") {
padding_mode_i_ = 1;
} else {
padding_mode_i_ = 0;
}
}
template <typename T>
Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input<Tensor>(1);
const auto& dims_grid = Grid->Shape().GetDims();
if (dims_input.size() != 4 || dims_grid.size() != 4) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
}
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
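// Output shape is (N, C, H_out, W_out); H_out and W_out come from the grid dimensions.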
TensorShapeVector dims_output(4);
dims_output[0] = dims_input[0];
dims_output[1] = dims_input[1];
dims_output[2] = dims_grid[1];
dims_output[3] = dims_grid[2];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
}
typedef typename ToHipType<T>::MappedType HipT;
HipT* Y_data = reinterpret_cast<HipT*>(Y->MutableData<T>());
GridSampleImpl<HipT>(
Stream(),
reinterpret_cast<const HipT*>(X->Data<T>()),
reinterpret_cast<const HipT*>(Grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
dims_input.data(),
dims_grid[1],
dims_grid[2],
Y_data);
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class GridSample final : public RocmKernel {
public:
explicit GridSample(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t mode_i_; // 0: bilinear (default), 1: nearest, 2: bicubic
int64_t padding_mode_i_; // 0: zeros (default), 1: border, 2: reflection
int64_t align_corners_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/cu_inc/common.cuh"
#include "grid_sample_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__device__ T GsDenormalize(T n, int64_t length, bool align_corners) {
T x = {};
if (align_corners) { // align_corners: true => [-1, 1] to [0, length - 1]
x = (n + static_cast<T>(1)) / static_cast<T>(2) * (length - 1);
} else { // align_corners: false => [-1, 1] to [-0.5, length - 0.5]
x = ((n + static_cast<T>(1)) * length - static_cast<T>(1)) / static_cast<T>(2);
}
return x;
}
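// GsReflect folds a coordinate that falls outside [x_min, x_max] back into range by reflecting it
// across the nearest boundary (the mapping repeats with period 2 * (x_max - x_min)).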
template <typename T>
__device__ T GsReflect(T x, float x_min, float x_max) {
float fx = static_cast<float>(x);
float dx = {};
float range = x_max - x_min;
if (fx < x_min) {
dx = x_min - fx;
int n = static_cast<int>(dx / range);
float r = dx - n * range;
if (n % 2 == 0) {
fx = x_min + r;
} else {
fx = x_max - r;
}
} else if (fx > x_max) {
dx = fx - x_max;
int n = static_cast<int>(dx / range);
float r = dx - n * range;
if (n % 2 == 0) {
fx = x_max - r;
} else {
fx = x_min + r;
}
}
// else fallthrough
return static_cast<T>(fx);
}
template <typename T>
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;
if (padding_mode == 0) { // zeros
if (x >= 0 && x < W && y >= 0 && y < H) {
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
}
} else if (padding_mode == 1) { // border
x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
} else { // Reflection
x = (int64_t) GsReflect<T>(x, border[0], border[2]);
y = (int64_t) GsReflect<T>(y, border[1], border[3]);
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
}
return pixel;
}
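// Coefficients of the cubic convolution (bicubic) kernel with alpha = -0.75, the value commonly used
// for image resampling; for a fractional offset x in [0, 1), the four weights correspond to the
// neighbors at distances 1 + x, x, 1 - x, and 2 - x.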
__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
float cubic_alpha = -0.75f;
x = abs(x);
coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
coeffs[1] = (((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1);
coeffs[2] = (((cubic_alpha + 2) * (1 - x) - (cubic_alpha + 3)) * (1 - x) * (1 - x) + 1);
coeffs[3] = (((cubic_alpha * (2 - x) - 5 * cubic_alpha) * (2 - x) + 8 * cubic_alpha) * (2 - x) - 4 * cubic_alpha);
}
template <typename T>
__device__ T GsBicubicInterpolate(T p[4][4], float x, float y) {
float v[4] = {};
float coeffs[4] = {};
GsGetCubicCoeffs(x, coeffs);
for (int64_t i = 0; i < 4; i++) {
v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
}
GsGetCubicCoeffs(y, coeffs);
T pixel = static_cast<T>(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
return pixel;
}
template <typename T>
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
const int64_t mode,
const int64_t padding_mode,
const int64_t align_corners,
const int64_t N,
const int64_t C,
const int64_t H_in,
const int64_t W_in,
const int64_t H_out,
const int64_t W_out,
T* output_data) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
int BIdx = idx / (C * H_out * W_out );
int tmpBCnt = BIdx * (C * H_out * W_out);
int cIdx = (idx - tmpBCnt) / (H_out * W_out);
int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
int yIdx = (idx - tmpCCnt) / W_out;
int tmpHCnt = tmpCCnt + yIdx * W_out;
int xIdx = (idx - tmpHCnt);
int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
T grid_X = grid_data[grid_idx * 2 + 0];
T grid_Y = grid_data[grid_idx * 2 + 1];
int outIdx = idx;
T grid_x_imgSpace = GsDenormalize(grid_X, W_in, align_corners == 1);
T grid_y_imgSpace = GsDenormalize(grid_Y, H_in, align_corners == 1);
if (mode == 1) { //nearest
grid_x_imgSpace = nearbyint(grid_x_imgSpace);
grid_y_imgSpace = nearbyint(grid_y_imgSpace);
}
float x_min = -0.5f;
float x_max = W_in - 0.5f;
float y_min = -0.5f;
float y_max = H_in - 0.5f;
if (align_corners) {
x_min = 0.0f;
x_max = W_in - 1.0f;
y_min = 0.0f;
y_max = H_in - 1.0f;
}
float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
}
}
if (mode == 0) { // bilinear
int x1 = floor(grid_x_imgSpace);
int y1 = floor(grid_y_imgSpace);
int x2 = x1 + 1;
int y2 = y1 + 1;
T w_lt = 0.0f;
T w_rt = 0.0f;
T w_lb = 0.0f;
T w_rb = 0.0f;
T w_r = grid_x_imgSpace - x1;
T w_l = 1.0f - w_r;
T w_b = grid_y_imgSpace - y1;
T w_t = 1.0f - w_b;
w_lt = w_t * w_l;
w_rt = w_t * w_r;
w_lb = w_b * w_l;
w_rb = w_b * w_r;
T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
}
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
int64_t x0 = static_cast<int64_t>(std::floor(grid_x_imgSpace)) - 1; // top-left corner of the bbox
int64_t y0 = static_cast<int64_t>(std::floor(grid_y_imgSpace)) - 1;
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
T dy = grid_y_imgSpace - y0 - 1;
output_data[outIdx] = GsBicubicInterpolate(p, dx, dy);
}
}
template <typename T>
void GridSampleImpl(
hipStream_t stream,
const T* input_data,
const T* grid_data,
const int64_t mode,
const int64_t padding_mode,
const int64_t align_corners,
const int64_t dims[4],
const int64_t H_out,
const int64_t W_out,
T* output_data) {
int blocksPerGrid = (int)(ceil(static_cast<T>(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
hipLaunchKernelGGL(HIP_KERNEL_NAME(_GridSampleKernel<T>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
}
#define SPECIALIZED_IMPL(T) \
template void GridSampleImpl<T>(hipStream_t stream, const T* input_data, const T* grid_data, \
const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
SPECIALIZED_IMPL(float)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
void GridSampleImpl(
hipStream_t stream,
const T* input_data,
const T* grid_data,
const int64_t mode,
const int64_t padding_mode,
const int64_t align_corners,
const int64_t dims_input[4],
const int64_t H_out,
const int64_t W_out,
T* output_data);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/nn/layer_norm.h"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// LayerNormalization is an official ONNX operator in opset 17.
#define REGISTER_KERNEL_TYPED(T, U, V) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(LayerNormalization, kOnnxDomain, 1, 16, T##_##U##_##V, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("U", DataTypeImpl::GetTensorType<U>()) \
.TypeConstraint("V", DataTypeImpl::GetTensorType<V>()), \
onnxruntime::rocm::LayerNorm<T, U, V, false>); \
ONNX_OPERATOR_TYPED_KERNEL_EX(SimplifiedLayerNormalization, kOnnxDomain, 1, T##_##U##_##V, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("U", DataTypeImpl::GetTensorType<U>()) \
.TypeConstraint("V", DataTypeImpl::GetTensorType<V>()), \
onnxruntime::rocm::LayerNorm<T, U, V, true>);
REGISTER_KERNEL_TYPED(float, float, float)
REGISTER_KERNEL_TYPED(double, double, double)
REGISTER_KERNEL_TYPED(MLFloat16, float, MLFloat16)
REGISTER_KERNEL_TYPED(float, float, MLFloat16)
REGISTER_KERNEL_TYPED(MLFloat16, float, float)
REGISTER_KERNEL_TYPED(BFloat16, float, BFloat16)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/bias_dropout.h"
#include "core/providers/common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace {
template <typename T>
struct GetRatioDataImpl {
void operator()(const Tensor* ratio, float& ratio_data) const {
ratio_data = static_cast<float>(*(ratio->Data<T>()));
ORT_ENFORCE(ratio_data >= 0.0f && ratio_data < 1.0f, "ratio_data is outside range [0, 1)");
}
};
template <typename T>
struct BiasDropoutComputeImpl {
Status operator()(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const int64_t mask_element_count,
const fast_divmod fdm_dim, const float ratio_data, PhiloxGenerator& generator, const Tensor& X,
const Tensor& bias, const Tensor* residual, Tensor& Y, void* mask_data, bool has_same_shape_bias,
bool use_bitmask) const {
typedef typename ToHipType<T>::MappedType HipT;
const HipT* X_data = reinterpret_cast<const HipT*>(X.Data<T>());
const HipT* bias_data = reinterpret_cast<const HipT*>(bias.Data<T>());
const HipT* residual_data = nullptr;
if (residual) {
if (residual->Shape() != X.Shape()) {
return Status(common::ONNXRUNTIME, common::FAIL, "Residual input shape does not match X input shape.");
}
residual_data = reinterpret_cast<const HipT*>(residual->Data<T>());
}
HipT* Y_data = reinterpret_cast<HipT*>(Y.MutableData<T>());
BiasDropoutKernelImpl<HipT>(prop, stream, N, mask_element_count, fdm_dim, ratio_data, generator, X_data, bias_data,
residual_data, Y_data, mask_data, has_same_shape_bias, use_bitmask);
return Status::OK();
}
};
} // namespace
ONNX_OPERATOR_KERNEL_EX(BiasDropout, kMSDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>())
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4),
BiasDropout<false>);
ONNX_OPERATOR_KERNEL_EX(BitmaskBiasDropout, kMSDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<BitmaskElementType>())
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4),
BiasDropout<true>);
template <bool UseBitmask>
Status BiasDropout<UseBitmask>::ComputeInternal(OpKernelContext* context) const {
// Get X_data
const Tensor* X = context->Input<Tensor>(0);
ORT_RETURN_IF_NOT(X, "X Input is not available.");
const TensorShape& x_shape = X->Shape();
const int64_t N = x_shape.Size();
// Get bias_data
const Tensor* bias = context->Input<Tensor>(1);
if (!bias) return Status(common::ONNXRUNTIME, common::FAIL, "Bias input of BiasDropout is not available.");
const TensorShape& bias_shape = bias->Shape();
const int64_t dim = bias_shape.GetDims().back();
bool has_same_shape_bias = (bias_shape == x_shape);
if (!has_same_shape_bias) {
if (bias_shape.NumDimensions() != 1) {
return Status(common::ONNXRUNTIME, common::FAIL, "Bias input is not a 1D tensor.");
}
if (dim != x_shape.GetDims().back()) {
return Status(common::ONNXRUNTIME, common::FAIL, "Bias' dimension doesn't match input's last dimension.");
}
}
// Get residual_data
const Tensor* residual = context->Input<Tensor>(2);
// Get Y_data
auto Y = context->Output(0, x_shape);
// Get mask_data
Tensor* mask = nullptr;
int64_t mask_element_count = N;
if (UseBitmask) {
mask_element_count = (N + kNumBitsPerBitmaskElement - 1) / kNumBitsPerBitmaskElement;
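// mask_element_count is ceil(N / kNumBitsPerBitmaskElement): one mask bit per element, packed into bitmask words.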
mask = context->Output(1, {mask_element_count});
} else {
mask = context->Output(1, x_shape);
}
// Get the ratio_data
float ratio_data = default_ratio_;
auto ratio = context->Input<Tensor>(3);
if (ratio) {
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(ratio->GetElementType());
t_disp.Invoke<GetRatioDataImpl>(ratio, ratio_data);
}
// Check for inference mode.
const Tensor* training_mode = context->Input<Tensor>(4);
bool is_training_mode = training_mode && *(training_mode->Data<bool>());
if (!is_training_mode) {
ratio_data = 0.0f;
}
IAllocatorUniquePtr<void> temp_mask_buffer{}; // buffer to use if mask is not provided
void* const mask_data = [this, mask_element_count, mask, &temp_mask_buffer]() {
if (mask) return mask->MutableDataRaw();
temp_mask_buffer =
GetScratchBuffer<void>(mask_element_count * (UseBitmask ? sizeof(BitmaskElementType) : sizeof(bool)));
return temp_mask_buffer.get();
}();
const fast_divmod fdm_dim(gsl::narrow_cast<int>(dim));
PhiloxGenerator& generator = generator_ ? *generator_ : PhiloxGenerator::Default();
utils::MLTypeCallDispatcher<float, MLFloat16, double, BFloat16> t_disp(X->GetElementType());
return t_disp.InvokeRet<Status, BiasDropoutComputeImpl>(GetDeviceProp(), Stream(), N, mask_element_count, fdm_dim,
ratio_data, generator, *X, *bias, residual, *Y, mask_data,
has_same_shape_bias, UseBitmask);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/framework/random_generator.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
void BiasDropoutKernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N,
const int64_t mask_element_count, const fast_divmod fdm_dim, const float ratio,
PhiloxGenerator& generator, const T* X_data, const T* bias_data, const T* residual_data,
T* Y_data, void* mask_data, bool has_same_shape_bias, bool use_bitmask);
template <bool UseBitmask>
class BiasDropout final : public RocmKernel {
public:
BiasDropout(const OpKernelInfo& info) : RocmKernel(info) {
int64_t seed = 0;
if (info.GetAttr<int64_t>("seed", &seed).IsOK()) {
generator_ = std::make_unique<PhiloxGenerator>(static_cast<uint64_t>(seed));
}
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
mutable std::unique_ptr<PhiloxGenerator> generator_;
static constexpr float default_ratio_ = 0.5f;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
#include "contrib_ops/rocm/math/bias_dropout.h"
#include <hiprand_kernel.h>
#include <algorithm>
#include "core/providers/rocm/cu_inc/bitmask.cuh"
namespace onnxruntime {
namespace contrib {
namespace rocm {
constexpr int kBlockSize = 256;
constexpr int kNumUnroll = 4;
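// kNumUnroll matches the 4 uniform floats produced by hiprand_uniform4, so each thread handles
// 4 elements per step (and, when UseBitmask is set, contributes 4 bits of the dropout mask per step).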
template <typename T, bool HasSameShapeBias, bool HasResidual, bool UseBitmask>
__global__ void BiasDropoutKernel(const HIP_LONG N, const HIP_LONG mask_element_count, const int step_size,
const int steps_per_thread, const fast_divmod fdm_bits_per_element,
const fast_divmod fdm_dim, const float ratio,
const std::pair<uint64_t, uint64_t> seeds, const T* X_data, const T* bias_data,
const T* residual_data, T* Y_data, void* mask_data) {
HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
const float p = 1.0f - ratio;
const float scale = 1.0f / p;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seeds.first, idx, seeds.second, &state);
float4 rand;
// We ensure every thread generates the same number of random numbers (by rounding
// up the size) and at the same timestep (by syncing threads).
// From the ROCm hiprand documentation:
// The Philox_4x32_10 algorithm is closely tied to the thread and block count.
// Each thread computes 4 random numbers at the same time, thus the most efficient
// use of Philox_4x32_10 is to generate a multiple of 4 times the number of threads.
for (int i = 0; i < steps_per_thread; ++i) {
HIP_LONG id = idx * kNumUnroll + i * step_size;
rand = hiprand_uniform4(&state);
BitmaskElementType thread_bitmask = 0;
// actual computation
#pragma unroll
for (int i = 0; i < kNumUnroll; ++i) {
HIP_LONG li = id + i;
if (li < N) {
float bias;
if (HasSameShapeBias) {
bias = static_cast<float>(bias_data[li]);
} else {
int offset = fdm_dim.mod(li);
bias = static_cast<float>(bias_data[offset]);
}
bool mask = (&rand.x)[i] < p;
float output_data = (static_cast<float>(X_data[li]) + bias) * mask * scale;
if (HasResidual) {
output_data += static_cast<float>(residual_data[li]);
}
Y_data[li] = static_cast<T>(output_data);
if (UseBitmask) {
thread_bitmask |= (mask << i);
} else {
reinterpret_cast<bool*>(mask_data)[li] = mask;
}
}
}
if (UseBitmask) {
SetBitmask<kNumUnroll>(id, mask_element_count, fdm_bits_per_element, thread_bitmask,
reinterpret_cast<BitmaskElementType*>(mask_data));
}
__syncthreads();
}
}
template <typename T, bool HasSameShapeBias, bool HasResidual, bool UseBitmask>
__global__ void BiasDropoutVectorizedKernel(const HIP_LONG N, const HIP_LONG mask_element_count, const int step_size,
const int steps_per_thread, const fast_divmod fdm_bits_per_element,
const fast_divmod fdm_dim, const float ratio,
const std::pair<uint64_t, uint64_t> seeds, const T* X_data,
const T* bias_data, const T* residual_data, T* Y_data, void* mask_data) {
HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
const float p = 1.0f - ratio;
const float scale = 1.0f / p;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seeds.first, idx, seeds.second, &state);
float4 rand;
// Use the vectorized load/store path when N % 4 == 0,
// since this is the typical case for input shape sizes.
using LoadT = aligned_vector<T, kNumUnroll>;
using MaskLoadT = aligned_vector<bool, kNumUnroll>;
using ResidualLoadT = aligned_vector<T, kNumUnroll>;
for (int i = 0; i < steps_per_thread; ++i) {
HIP_LONG id = idx * kNumUnroll + i * step_size;
rand = hiprand_uniform4(&state);
BitmaskElementType thread_bitmask = 0;
if (id < N) {
// vectorized load into storage
T bias_vec[kNumUnroll];
if (HasSameShapeBias) {
LoadT* value0 = reinterpret_cast<LoadT*>(&bias_vec);
*value0 = *reinterpret_cast<const LoadT*>(&bias_data[id]);
}
T src[kNumUnroll];
LoadT* value1 = reinterpret_cast<LoadT*>(&src);
*value1 = *reinterpret_cast<const LoadT*>(&X_data[id]);
T residual[kNumUnroll];
if (HasResidual) {
ResidualLoadT* value2 = reinterpret_cast<ResidualLoadT*>(&residual);
*value2 = *reinterpret_cast<const ResidualLoadT*>(&residual_data[id]);
}
T r[kNumUnroll];
bool masks[kNumUnroll];
// actual computation
#pragma unroll
for (int ii = 0; ii < kNumUnroll; ii++) {
float bias;
if (HasSameShapeBias) {
bias = static_cast<float>(bias_vec[ii]);
} else {
int offset = fdm_dim.mod(id + ii);
bias = static_cast<float>(bias_data[offset]);
}
bool mask = (&rand.x)[ii] < p;
float output_data = (static_cast<float>(src[ii]) + bias) * mask * scale;
if (HasResidual) {
output_data += static_cast<float>(residual[ii]);
}
r[ii] = static_cast<T>(output_data);
if (UseBitmask) {
thread_bitmask |= (mask << ii);
} else {
masks[ii] = mask;
}
}
// Vectorized writes for mask_data & Y_data
*(reinterpret_cast<LoadT*>(&Y_data[id])) = *reinterpret_cast<LoadT*>(&r[0]);
if (!UseBitmask) {
*(reinterpret_cast<MaskLoadT*>(&reinterpret_cast<bool*>(mask_data)[id])) =
*reinterpret_cast<MaskLoadT*>(&masks[0]);
}
}
if (UseBitmask) {
SetBitmask<kNumUnroll>(id, mask_element_count, fdm_bits_per_element, thread_bitmask,
reinterpret_cast<BitmaskElementType*>(mask_data));
}
__syncthreads();
}
}
#define LAUNCH_BIAS_DROPOUT_KERNEL(FuncName, HasSameShapeBias, HasResidual, UseBitmask) \
hipLaunchKernelGGL(HIP_KERNEL_NAME(FuncName<T, HasSameShapeBias, HasResidual, UseBitmask>), grid_size, kBlockSize, 0, stream, \
static_cast<HIP_LONG>(N), static_cast<HIP_LONG>(mask_element_count), step_size, steps_per_thread, \
fdm_bits_per_element, fdm_dim, ratio, seeds, X_data, bias_data, residual_data, Y_data, mask_data)
#define HANDLE_BIAS_DROPOUT_USE_BITMASK(FuncName, HasSameShapeBias, HasResidual) \
if (use_bitmask) { \
LAUNCH_BIAS_DROPOUT_KERNEL(FuncName, HasSameShapeBias, HasResidual, true); \
} else { \
LAUNCH_BIAS_DROPOUT_KERNEL(FuncName, HasSameShapeBias, HasResidual, false); \
}
#define HANDLE_BIAS_DROPOUT_HAS_RESIDUAL(FuncName, HasSameShapeBias) \
if (residual_data) { \
HANDLE_BIAS_DROPOUT_USE_BITMASK(FuncName, HasSameShapeBias, true); \
} else { \
HANDLE_BIAS_DROPOUT_USE_BITMASK(FuncName, HasSameShapeBias, false); \
}
#define HANDLE_BIAS_DROPOUT_HAS_SAME_SHAPE_BIAS(FuncName) \
if (has_same_shape_bias) { \
HANDLE_BIAS_DROPOUT_HAS_RESIDUAL(FuncName, true); \
} else { \
HANDLE_BIAS_DROPOUT_HAS_RESIDUAL(FuncName, false); \
}
template <typename T>
void BiasDropoutKernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N,
const int64_t mask_element_count, const fast_divmod fdm_dim, const float ratio,
PhiloxGenerator& generator, const T* X_data, const T* bias_data, const T* residual_data,
T* Y_data, void* mask_data, bool has_same_shape_bias, bool use_bitmask) {
const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / kBlockSize;
const int grid_size =
std::min(prop.multiProcessorCount * blocks_per_sm, static_cast<int>(CeilDiv(N, kBlockSize * kNumUnroll)));
// Compute the number of random numbers generated by each thread, and increment philox generator offset by that
// amount.
const int step_size = kBlockSize * grid_size * kNumUnroll;
const int steps_per_thread = static_cast<int>(CeilDiv(N, step_size));
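// In each step the whole grid covers step_size = kBlockSize * grid_size * kNumUnroll elements,
// so steps_per_thread iterations cover all N elements.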
auto seeds = generator.NextPhiloxSeeds(static_cast<uint64_t>(steps_per_thread * kNumUnroll));
fast_divmod fdm_bits_per_element(kNumBitsPerBitmaskElement);
if (N % kNumUnroll != 0) {
HANDLE_BIAS_DROPOUT_HAS_SAME_SHAPE_BIAS(BiasDropoutKernel);
} else {
HANDLE_BIAS_DROPOUT_HAS_SAME_SHAPE_BIAS(BiasDropoutVectorizedKernel);
}
}
#undef HANDLE_BIAS_DROPOUT_HAS_SAME_SHAPE_BIAS
#undef HANDLE_BIAS_DROPOUT_HAS_RESIDUAL
#undef HANDLE_BIAS_DROPOUT_USE_BITMASK
#undef LAUNCH_BIAS_DROPOUT_KERNEL
#define SPECIALIZED_BIAS_DROPOUT_IMPL(T) \
template void BiasDropoutKernelImpl<T>( \
const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const int64_t mask_element_count, \
const fast_divmod fdm_dim, const float ratio, PhiloxGenerator& generator, const T* X_data, const T* bias_data, \
const T* residual_data, T* Y_data, void* mask_data, bool has_same_shape_bias, bool use_bitmask);
SPECIALIZED_BIAS_DROPOUT_IMPL(float)
SPECIALIZED_BIAS_DROPOUT_IMPL(double)
SPECIALIZED_BIAS_DROPOUT_IMPL(half)
SPECIALIZED_BIAS_DROPOUT_IMPL(BFloat16)
#undef SPECIALIZED_BIAS_DROPOUT_IMPL
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime