Commit 1a91fcc2 authored by gaoqiong's avatar gaoqiong
Browse files

add files required by DTK

parent a144865d
Pipeline #492 failed with stages
in 0 seconds
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include "core/common/common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Launch the "simple" (non-compact-memory) Longformer attention softmax kernels.
// Computes the attention output = softmax(scaler * Q*K^T + mask) * V for both the
// sliding-window (local) attention and the global-attention tokens, writing into
// `output`. All pointers are device buffers unless noted; work is enqueued on `stream`.
Status LaunchLongformerSoftmaxSimpleKernel(
hipStream_t stream,
rocblas_handle rocblas,
void* workspace, // softmax space
const void* q, // transposed Q with shape (B, N, S, H)
const void* k, // transposed K with shape (B, N, S, H)
const void* v, // transposed V with shape (B, N, S, H)
const void* attention_mask, // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
const void* global_q, // Q for global tokens with shape (B, N, S, H)
const void* global_k, // K for global tokens with shape (B, N, S, H)
const void* global_v, // V for global tokens with shape (B, N, S, H)
const int* global_attention, // global attention flags with shape (B, S), with value 0 for local and 1 for global.
const int* global_index, // Global index with shape (B, S)
const int* batch_global_num, // Number of global tokens per batch with shape (B, 1)
void* pinned_buffer, // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
void* output, // output with shape (B, N, S, H)
float scaler, // scaling factor applied to Q*K^T (presumably 1/sqrt(head_size) — confirm at call site)
int batch_size, // batch size
int sequence_length, // sequence length
int num_heads, // number of heads
int head_size, // hidden size per head
int attention_window, // one sided windows size
size_t element_size); // sizeof the element type (e.g. 2 for fp16, 4 for fp32)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <hipcub/hipcub.hpp>
#include <hipcub/device/device_partition.hpp>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "longformer_global_impl.h"
using namespace onnxruntime::rocm;
using namespace hipcub;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Returns the size in bytes of the scratch buffer needed by BuildGlobalIndex.
// Layout of the scratch buffer:
//   [sequence_index: int x S][tmp_storage: int x 1024]
// where the 1024-int tail is reserved for hipcub temporary device storage.
size_t GetGlobalScratchSize(int sequence_length) {
  const int kTempStorageInts = 1024;  // reserved for hipcub::DevicePartition temp storage
  return sizeof(int) * (sequence_length + kTempStorageInts);
}
// Fills sequence_index[i] = i for i in [0, sequence_length).
// Fix: the original loop advanced by blockDim.x only, so with the multi-block
// launch in BuildGlobalIndex every block re-walked the ranges owned by the
// blocks after it (harmless because the writes are idempotent, but redundant).
// A proper grid-stride loop advances by blockDim.x * gridDim.x, which also
// keeps the kernel correct for any launch configuration.
__global__ void InitSequenceIndexKernel(int* sequence_index, int sequence_length) {
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < sequence_length; i += stride) {
    sequence_index[i] = i;
  }
}
// For each batch row, computes the positions of global-attention tokens and how
// many there are, using hipcub::DevicePartition::Flagged with `global_attention`
// as the flag array.
//   global_attention: (B, S) device flags; 1 = global token, 0 = local.
//   global_index:     (B, S) device output; the first batch_global_num[i] entries
//                     of row i are the token positions flagged global in that row.
//   batch_global_num: (B) device output; count of global tokens per batch row.
//   scratch:          device buffer laid out as [sequence_index: int S][tmp_storage: int 1024]
//                     (see GetGlobalScratchSize). All work is enqueued on `stream`.
Status BuildGlobalIndex(
const hipDeviceProp_t& device_prop,
hipStream_t stream,
const int* global_attention,
int batch_size,
int sequence_length,
int* global_index,
int* batch_global_num,
void* scratch,
size_t scratch_size) {
int* sequence_index = (int*)scratch;
int* tmp_storage = sequence_index + sequence_length;
const int threads = device_prop.maxThreadsPerBlock;
int blocks = CeilDiv(sequence_length, threads);
// Initialize sequence_index to 0..S-1; these values are partitioned below.
hipLaunchKernelGGL(InitSequenceIndexKernel, blocks, threads, 0, stream, sequence_index, sequence_length);
// Determine temporary device storage size.
// For int* inputs/outputs, it need 767 bytes. We reserved 1024*4 bytes, which shall be enough.
size_t temp_storage_bytes = 0;
HIP_RETURN_IF_ERROR(hipcub::DevicePartition::Flagged(
NULL, temp_storage_bytes, sequence_index,
global_attention, global_index, batch_global_num, sequence_length, stream));
// Guard against the reserved 1024-int tail being too small for hipcub's request.
if (temp_storage_bytes + sizeof(int) * sequence_length > scratch_size) {
ORT_THROW("LongformerAttention scratch space is not large enough. Temp storage bytes are", temp_storage_bytes);
}
// Find the global attention indices and number of global attention tokens.
// One partition call per batch row; all calls share tmp_storage, which is safe
// because they are serialized on the same stream.
for (int i = 0; i < batch_size; ++i) {
HIP_RETURN_IF_ERROR(hipcub::DevicePartition::Flagged(
reinterpret_cast<void*>(tmp_storage), temp_storage_bytes, sequence_index,
global_attention + i * sequence_length, global_index + i * sequence_length,
batch_global_num + i, sequence_length, stream));
}
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Size in bytes of the global-index scratch buffer required by BuildGlobalIndex,
// laid out as [sequence_index: int S][hipcub temp storage: int 1024].
size_t GetGlobalScratchSize(int sequence_length);
// Find the global attention indices and number of global attention tokens.
// Fills global_index (B, S) with per-row global-token positions and
// batch_global_num (B) with per-row counts; `scratch` must be at least
// GetGlobalScratchSize(sequence_length) bytes. Work is enqueued on `stream`.
Status BuildGlobalIndex(
const hipDeviceProp_t& device_prop,
hipStream_t stream,
const int* global_attention,
int batch_size,
int sequence_length,
int* global_index,
int* batch_global_num,
void* scratch,
size_t scratch_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "ngram_repeat_block.h"
#include "ngram_repeat_block_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Register the NGramRepeatBlock kernel (Microsoft domain, opset 1) with the
// ROCm execution provider. Tid (input_ids) is constrained to int64 and
// T (scores) to float.
ONNX_OPERATOR_KERNEL_EX(
NGramRepeatBlock,
kMSDomain,
1,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("Tid", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
NGramRepeatBlock);
using namespace ONNX_NAMESPACE;
// Reads the required "ngram_size" attribute and validates that it is positive.
NGramRepeatBlock::NGramRepeatBlock(const OpKernelInfo& info) : RocmKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("ngram_size", &ngram_size_).IsOK());
ORT_ENFORCE(ngram_size_ > 0);
}
// Copies `scores` to the output, then bans (sets to -inf) the log-prob of any
// token that would complete an ngram of length ngram_size_ already present in
// input_ids. Inputs: input_ids (batch, cur_len) int64, scores (batch, vocab)
// float; output has the same shape as scores.
Status NGramRepeatBlock::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_ids = context->Input<Tensor>(0);
const Tensor* scores = context->Input<Tensor>(1);
Tensor* output = context->Output(0, scores->Shape());
const auto* scores_source = static_cast<const float*>(scores->DataRaw());
auto* scores_target = static_cast<float*>(output->MutableDataRaw());
// Only copy when the op is not running in-place (output aliasing input).
if (scores_source != scores_target) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(scores_target, scores_source, scores->Shape().Size() * sizeof(float), hipMemcpyDeviceToDevice, Stream()));
}
const auto& input_ids_dims = input_ids->Shape().GetDims();
const auto& scores_dims = scores->Shape().GetDims();
ORT_ENFORCE(input_ids_dims.size() == 2);
ORT_ENFORCE(scores_dims.size() == 2);
int64_t batch_size = input_ids_dims[0];
int64_t cur_len = input_ids_dims[1];
ORT_ENFORCE(scores_dims[0] == batch_size);
int64_t vocab_size = scores_dims[1];
// The sequence is too short to contain a repeated ngram; scores pass through unchanged.
if (cur_len + 1 < ngram_size_) {
return Status::OK();
}
const auto* input_ids_data = static_cast<const int64_t*>(input_ids->DataRaw(input_ids->DataType()));
// step = cur_len - 1 (index of the last generated token), beam_size = 1:
// beams are expected to be folded into the batch dimension here.
NGramRepeatBlockImpl(
Stream(),
input_ids_data,
scores_target,
gsl::narrow_cast<int>(batch_size),
gsl::narrow_cast<int>(cur_len - 1),
gsl::narrow_cast<int>(cur_len),
gsl::narrow_cast<int>(vocab_size),
gsl::narrow_cast<int>(1),
gsl::narrow_cast<int>(ngram_size_));
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// ROCm kernel that blocks repeated ngrams during beam-search decoding:
// sets the score of any token that would complete an already-seen ngram
// of length ngram_size_ to -inf.
class NGramRepeatBlock final : public RocmKernel {
public:
NGramRepeatBlock(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
// Length of the ngram to block; read from the "ngram_size" attribute, must be > 0.
int64_t ngram_size_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for blocking repeated n-grams.
*/
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/ngram_repeat_block_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// Ban repeated ngrams of length = 'no_repeat_ngram_size'
// Ban repeated ngrams of length = 'no_repeat_ngram_size'.
// One block per (batch * beam) row of `tokens`. Thread `col` compares the
// ngram prefix starting at position `col` against the trailing ngram of the
// sequence; on a match, the token that would complete the repeat gets its
// log-prob set to -INFINITY.
// Launch contract (see NGramRepeatBlockImpl): blockDim.x = step - ngram_size + 2
// and dynamic shared memory = (step + 1) * sizeof(int64_t); the comparison
// below indexes tokens_shm up to blockDim.x + ngram_size - 2 == step, so both
// must be respected by the launcher.
__global__ void banRepeatedTokens(const int64_t* __restrict__ tokens,
float* __restrict__ lprobs,
int max_predict_len, int vocab_size,
int no_repeat_ngram_size) {
auto row = blockIdx.x;
auto col = threadIdx.x;
auto start = row * (max_predict_len) + col;
// Each thread compares ngram starting from
// thread index with final ngram starting from
// step - no_repeat_ngram_size + 2 (== blockDim.x, i.e. check_start_pos below).
auto check_start_pos = blockDim.x;
auto lprob_start = row * vocab_size;
bool is_banned = true;
extern __shared__ int64_t tokens_shm[];
// Stage this row's tokens in shared memory: thread `col` loads token `col`,
// and the last thread additionally loads the trailing ngram_size-1 tokens.
tokens_shm[col] = tokens[start];
if (col == blockDim.x - 1) {
for (int i = 1; i < no_repeat_ngram_size; i++) {
if (col + i < max_predict_len) {
tokens_shm[col + i] = tokens[start + i];
}
}
}
// Barrier is outside any divergent branch: all threads in the block reach it.
__syncthreads();
// Compare the (ngram_size-1)-token prefix at `col` with the trailing prefix.
if (is_banned == true) branch then bans the token following the matched prefix.
for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
is_banned = false;
}
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}
}
// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel.
// step is the index of the last generated token (cur_len - 1 at the call site);
// max_predict_len is the row stride of tokens_ptr.
void NGramRepeatBlockImpl(
hipStream_t stream,
const int64_t* tokens_ptr,
float* scores_ptr,
int bsz,
int step,
int max_predict_len,
int vocab_size,
int beam_size,
int no_repeat_ngram_size) {
int threads = step - no_repeat_ngram_size + 2;
// Sequence too short to contain any ngram of the requested size: nothing to ban.
if (threads <= 0) return;
int blocks = bsz * beam_size;
int shared_mem_size = (step + 1) * sizeof(int64_t);
// Launching N blocks where N is number of samples in a batch (beams*bsz)
// Launching T threads where T is number of previous ngrams in a sample
// Allocating shared mem per block for fastser access of input tokens since
// each token will be accessed N times to compare with current Ngram where
// N is Ngram size.
// NOTE(review): `threads` is not clamped to the device's maxThreadsPerBlock
// (typically 1024), and there is no hipGetLastError() check after the launch,
// so very long sequences would fail silently — confirm upstream guarantees
// step stays small enough.
hipLaunchKernelGGL(banRepeatedTokens, blocks, threads, shared_mem_size, stream,
tokens_ptr, scores_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// Launches the repeated-ngram banning kernel on `stream`: for each of the
// bsz*beam_size rows in tokens_ptr (row stride max_predict_len), sets
// scores_ptr entries to -inf for tokens that would complete an ngram of
// length no_repeat_ngram_size already present in that row. No-op when the
// sequence is too short (step - no_repeat_ngram_size + 2 <= 0).
void NGramRepeatBlockImpl(
hipStream_t stream,
const int64_t* tokens_ptr,
float* scores_ptr,
int bsz,
int step,
int max_predict_len,
int vocab_size,
int beam_size,
int no_repeat_ngram_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// ROCm kernel declaration for the RemovePadding contrib op, templated on the
// element type T. Implementation is not visible in this file.
template <typename T>
class RemovePadding final : public RocmKernel {
public:
RemovePadding(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
// ROCm kernel declaration for the RestorePadding contrib op, templated on the
// element type T. Implementation is not visible in this file.
template <typename T>
class RestorePadding final : public RocmKernel {
public:
RestorePadding(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment