Commit 23781e21 authored by pengcheng888's avatar pengcheng888
Browse files

issue/521 - support topksoftmax on metax GPU.

parent f5e6d729
...@@ -5,16 +5,13 @@ ...@@ -5,16 +5,13 @@
#include <cub/block/block_radix_sort.cuh> #include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh> #include <cub/block/block_store.cuh>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
template <typename T> template <typename T>
inline __device__ float exp_func(T x) { inline __device__ float exp_func(T x) {
float data; float data;
if constexpr (std::is_same_v<T, float>) { if constexpr (std::is_same_v<T, float>) {
data = x; data = x;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) { } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
data = __bfloat162float(x); data = __bfloat162float(x);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
data = __half2float(x); data = __half2float(x);
......
// Metax (MACA) backend declaration for the topksoftmax operator.
#ifndef __TOPKSOFTMAX_METAX_CUH__
#define __TOPKSOFTMAX_METAX_CUH__
#include "../topksoftmax.h"
// Presumably expands to the op::topksoftmax::metax::Descriptor class
// declaration (macro defined in ../topksoftmax.h) — the class is
// implemented in the accompanying .maca source. TODO confirm macro body.
DESCRIPTOR(metax)
#endif
#include "../../../devices/metax/metax_common.h"
#include "topksoftmax_metax.cuh"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
namespace op::topksoftmax::metax {

// Opaque state carried by the descriptor: the Metax device handle internals
// needed when launching kernels.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

// Validates the input tensor descriptor and constructs a Metax topksoftmax
// descriptor.
//
// Parameters:
//   handle   - InfiniOp handle; must wrap a Metax device handle.
//   desc_ptr - out: receives the newly allocated descriptor on success.
//   x_desc   - descriptor of the input logits tensor.
//
// Returns INFINI_STATUS_BAD_TENSOR_STRIDES when the innermost dimension is
// not contiguous (the kernel reads each row with unit stride), otherwise the
// status from TopksoftmaxInfo::create or INFINI_STATUS_SUCCESS.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc) {
    auto result = TopksoftmaxInfo::create(x_desc);
    CHECK_RESULT(result);
    auto info = result.take();
    if (info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        std::move(info),
        0, // workspace size: this backend needs no scratch memory
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

namespace {
// Dispatches softmax_topk_row_kernel on the input dtype.
// Launch layout: one block per row (grid = N), BLOCK_SIZE threads per block
// cooperating on a row of `width` elements, on the given hcStream_t.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes.
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topksoftmax(float *d_values_out, int *d_indices_out, const void *d_input, const size_t N, const size_t width, const size_t topk, const bool norm, infiniDtype_t xtype, hcStream_t stream) {
    const dim3 blocks(static_cast<unsigned int>(N));
    const dim3 threads(BLOCK_SIZE);

    switch (xtype) {
    case INFINI_DTYPE_F32:
        softmax_topk_row_kernel<float, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, N, width, topk, norm);
        break;
    case INFINI_DTYPE_F16:
        softmax_topk_row_kernel<half, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, N, width, topk, norm);
        break;
    case INFINI_DTYPE_BF16:
        softmax_topk_row_kernel<__hpcc_bfloat16, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (__hpcc_bfloat16 *)d_input, N, width, topk, norm);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
}; // namespace

// Runs softmax over each of the N rows and writes the top-k probabilities
// (`values`) and their column indices (`indices`). When `norm` is true the
// kernel presumably renormalizes the selected top-k values — confirm against
// the shared kernel in ../cuda/kernel.cuh.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE if the caller's workspace is
// smaller than requested, INFINI_STATUS_INTERNAL_ERROR for rows wider than
// 512 (no block size instantiated for them), or the launch status.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const size_t topk,
    const bool norm,
    void *stream_) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    const size_t N = _info.N;
    const size_t width = _info.width;
    auto stream = reinterpret_cast<hcStream_t>(stream_);

    // Pick the smallest instantiated block size covering the row width.
    // BUGFIX: propagate the dispatcher's status instead of discarding it —
    // previously a bad-dtype failure inside launch_topksoftmax was silently
    // reported as success.
    if (width <= 128) {
        return launch_topksoftmax<128>(values, indices, x, N, width, topk, norm, _info.xtype, stream);
    } else if (width <= 256) {
        return launch_topksoftmax<256>(values, indices, x, N, width, topk, norm, _info.xtype, stream);
    } else if (width <= 512) {
        return launch_topksoftmax<512>(values, indices, x, N, width, topk, norm, _info.xtype, stream);
    }
    return INFINI_STATUS_INTERNAL_ERROR;
}
} // namespace op::topksoftmax::metax
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) #if defined(ENABLE_NVIDIA_API)
#include "nvidia/topksoftmax_nvidia.cuh" #include "nvidia/topksoftmax_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API
#include "metax/topksoftmax_metax.cuh"
#endif
__C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle, __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
infiniopTopksoftmaxDescriptor_t *desc_ptr, infiniopTopksoftmaxDescriptor_t *desc_ptr,
...@@ -24,6 +27,9 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle, ...@@ -24,6 +27,9 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#endif #endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia); CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri ...@@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#endif #endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia); GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -71,6 +80,9 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi ...@@ -71,6 +80,9 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#endif #endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr ...@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#endif #endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia); DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
......
...@@ -115,7 +115,7 @@ if __name__ == "__main__": ...@@ -115,7 +115,7 @@ if __name__ == "__main__":
# x_shape, x_strides, topk, norm # x_shape, x_strides, topk, norm
((1, 32), None, 4, True), ((1, 32), None, 4, True),
((8, 20), None, 8, False), ((8, 20), None, 8, False),
((2, 128), None, 10, True) ((2, 64), None, 6, True)
] ]
_TENSOR_DTYPES_ = [np.float32, np.float16] _TENSOR_DTYPES_ = [np.float32, np.float16]
for dtype in _TENSOR_DTYPES_: for dtype in _TENSOR_DTYPES_:
......
...@@ -28,8 +28,8 @@ from libinfiniop import ( ...@@ -28,8 +28,8 @@ from libinfiniop import (
_TEST_CASES_ = [ _TEST_CASES_ = [
# x_shape, x_stride, topk, norm # x_shape, x_stride, topk, norm
((1, 10), None, 7, True), ((1, 10), None, 7, True),
((2, 20), None, 4, True), ((8, 20), None, 4, True),
((1, 128), None, 10, True), ((2, 64), None, 6, True),
] ]
# w (weight) types # w (weight) types
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment