"docs/git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "0b4311cd7d5487f0e5bc37d2bbe9654ce81c0fdb"
Unverified commit 05a2e149, authored by gongchensu, committed by GitHub

issue/383: Add logsoftmax ops


Co-authored-by: wawahejun <hejunlbbc@gmail.com>
Co-authored-by: zhuyue <zhuyue@qiyuanlab.com>
parent 79dbccd9
@@ -6,6 +6,7 @@
 #include "infiniop/ops/attention.h"
 #include "infiniop/ops/causal_softmax.h"
 #include "infiniop/ops/clip.h"
+#include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/conv.h"
 #include "infiniop/ops/dequantize_awq.h"
 #include "infiniop/ops/gemm.h"
...
#ifndef __INFINIOP_LOGSOFTMAX_API_H__
#define __INFINIOP_LOGSOFTMAX_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopLogSoftmaxDescriptor_t;
__C __export infiniStatus_t infiniopCreateLogSoftmaxDescriptor(infiniopHandle_t handle,
infiniopLogSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc);
__C __export infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopLogSoftmax(infiniopLogSoftmaxDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);
__C __export infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc);
#endif
@@ -17,6 +17,7 @@ def run_tests(args):
     "causal_softmax.py",
     "clip.py",
     "gemm.py",
+    "logsoftmax.py",
     "mul.py",
     "random_sample.py",
     "rearrange.py",
...
#include "logsoftmax_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
#include <algorithm>
#include <cmath>
namespace op::logsoftmax::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
auto result = LogSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tx, typename Ty>
infiniStatus_t logsoftmax(const LogSoftmaxInfo *info, Ty *y, const Tx *x) {
#pragma omp parallel for
for (ptrdiff_t batch = 0; batch < ptrdiff_t(info->batch_size); batch++) {
ptrdiff_t y_offset, x_offset;
if (info->ndim == 3) {
// For 3D tensors, convert linear batch index back to 2D indices
ptrdiff_t batch_idx = batch / info->seq_len;
ptrdiff_t seq_idx = batch % info->seq_len;
y_offset = batch_idx * info->y_stride_0 + seq_idx * info->y_stride_1;
x_offset = batch_idx * info->x_stride_0 + seq_idx * info->x_stride_1;
} else {
// For 2D tensors, use the flattened strides
y_offset = batch * info->y_stride_b;
x_offset = batch * info->x_stride_b;
}
Ty *y_ = y + y_offset;
const Tx *x_ = x + x_offset;
// Find the row max for numerical stability; reduce_op::max handles
// fp16/bf16/f32 inputs uniformly, so no dtype branch is needed here
float max_val = op::common_cpu::reduce_op::max(x_, info->probs_size, info->x_stride_p);
// Compute exp(x - max) and sum
float sum = 0.0f;
for (size_t i = 0; i < info->probs_size; i++) {
float x_val;
if constexpr (std::is_same<Tx, fp16_t>::value || std::is_same<Tx, bf16_t>::value) {
x_val = utils::cast<float>(x_[i * info->x_stride_p]);
} else {
x_val = x_[i * info->x_stride_p];
}
sum += std::exp(x_val - max_val);
}
// Compute log(sum)
float log_sum = std::log(sum);
// Compute log_softmax = x - max - log(sum)
for (size_t i = 0; i < info->probs_size; i++) {
float x_val;
if constexpr (std::is_same<Tx, fp16_t>::value || std::is_same<Tx, bf16_t>::value) {
x_val = utils::cast<float>(x_[i * info->x_stride_p]);
} else {
x_val = x_[i * info->x_stride_p];
}
float result = x_val - max_val - log_sum;
if constexpr (std::is_same<Ty, fp16_t>::value || std::is_same<Ty, bf16_t>::value) {
y_[i * info->y_stride_p] = utils::cast<Ty>(result);
} else {
y_[i * info->y_stride_p] = result;
}
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream) const {
// Handle different input/output dtype combinations
if (_info.x_dtype == INFINI_DTYPE_F16) {
if (_info.y_dtype == INFINI_DTYPE_F16) {
return logsoftmax<fp16_t, fp16_t>(&_info, (fp16_t *)y, (const fp16_t *)x);
} else if (_info.y_dtype == INFINI_DTYPE_BF16) {
return logsoftmax<fp16_t, bf16_t>(&_info, (bf16_t *)y, (const fp16_t *)x);
} else if (_info.y_dtype == INFINI_DTYPE_F32) {
return logsoftmax<fp16_t, float>(&_info, (float *)y, (const fp16_t *)x);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.x_dtype == INFINI_DTYPE_BF16) {
if (_info.y_dtype == INFINI_DTYPE_F16) {
return logsoftmax<bf16_t, fp16_t>(&_info, (fp16_t *)y, (const bf16_t *)x);
} else if (_info.y_dtype == INFINI_DTYPE_BF16) {
return logsoftmax<bf16_t, bf16_t>(&_info, (bf16_t *)y, (const bf16_t *)x);
} else if (_info.y_dtype == INFINI_DTYPE_F32) {
return logsoftmax<bf16_t, float>(&_info, (float *)y, (const bf16_t *)x);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.x_dtype == INFINI_DTYPE_F32) {
if (_info.y_dtype == INFINI_DTYPE_F16) {
return logsoftmax<float, fp16_t>(&_info, (fp16_t *)y, (const float *)x);
} else if (_info.y_dtype == INFINI_DTYPE_BF16) {
return logsoftmax<float, bf16_t>(&_info, (bf16_t *)y, (const float *)x);
} else if (_info.y_dtype == INFINI_DTYPE_F32) {
return logsoftmax<float, float>(&_info, (float *)y, (const float *)x);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::logsoftmax::cpu
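The CPU kernel above uses the standard numerically stable formulation, log_softmax(x) = x - max(x) - log(sum(exp(x - max(x)))). A minimal NumPy sketch (illustrative only, not part of this patch) of why the max shift matters:

```python
import numpy as np

def log_softmax_naive(x):
    # Direct formula: log(exp(x) / sum(exp(x))). Overflows for large inputs.
    e = np.exp(x)
    return np.log(e / e.sum())

def log_softmax_stable(x):
    # Same three passes as the CPU kernel: max, sum of exp(x - max), subtract.
    m = x.max()
    s = np.exp(x - m).sum()
    return x - m - np.log(s)

x = np.array([1000.0, 1001.0, 1002.0])
print(log_softmax_naive(x))   # [nan nan nan] -- exp(1000) overflows to inf
print(log_softmax_stable(x))  # [-2.4076 -1.4076 -0.4076]
```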
#ifndef __LOGSOFTMAX_CPU_H__
#define __LOGSOFTMAX_CPU_H__
#include "../logsoftmax.h"
DESCRIPTOR(cpu)
#endif
#ifndef __LOGSOFTMAX_KERNEL_CUH__
#define __LOGSOFTMAX_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
#include <type_traits>
template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__device__ void logSoftmaxKernel(
Tdata_out *y, const Tdata_in *x,
size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {
typedef cub::BlockReduce<Tcompute, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ Tcompute shared_max_val;
__shared__ Tcompute shared_sum_exp;
int batch_idx = blockIdx.x;
int tid = threadIdx.x;
if (batch_idx >= batch_size) {
return;
}
// Calculate correct memory offsets for 3D tensors
ptrdiff_t y_offset, x_offset;
if (ndim == 3) {
// For 3D tensors, convert linear batch index back to 2D indices
ptrdiff_t batch_dim_idx = batch_idx / seq_len;
ptrdiff_t seq_dim_idx = batch_idx % seq_len;
y_offset = batch_dim_idx * y_stride_0 + seq_dim_idx * y_stride_1;
x_offset = batch_dim_idx * x_stride_0 + seq_dim_idx * x_stride_1;
} else {
// For 2D tensors, use the flattened strides
y_offset = batch_idx * y_stride_b;
x_offset = batch_idx * x_stride_b;
}
const Tdata_in *x_batch = x + x_offset;
Tdata_out *y_batch = y + y_offset;
// Find maximum value for numerical stability
Tcompute max_val = static_cast<Tcompute>(-INFINITY);
for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
if constexpr (std::is_same_v<Tcompute, float>) {
max_val = fmaxf(max_val, val);
} else {
max_val = fmax(max_val, val);
}
}
max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
if (tid == 0) {
shared_max_val = max_val;
}
__syncthreads();
// Compute sum of exp(x - max)
Tcompute sum_exp = static_cast<Tcompute>(0.0);
for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
if constexpr (std::is_same_v<Tcompute, float>) {
sum_exp += expf(val - shared_max_val);
} else {
sum_exp += exp(val - shared_max_val);
}
}
sum_exp = BlockReduce(temp_storage).Sum(sum_exp);
if (tid == 0) {
shared_sum_exp = sum_exp;
}
__syncthreads();
// Compute log_softmax = x - max - log(sum_exp)
Tcompute log_sum_exp;
if constexpr (std::is_same_v<Tcompute, float>) {
log_sum_exp = logf(shared_sum_exp);
} else {
log_sum_exp = log(shared_sum_exp);
}
for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
Tcompute result = val - shared_max_val - log_sum_exp;
y_batch[i * y_stride_p] = static_cast<Tdata_out>(result);
}
}
template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__global__ void logSoftmax(
Tdata_out *y, const Tdata_in *x,
size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {
logSoftmaxKernel<BLOCK_SIZE, Tdata_out, Tdata_in, Tcompute>(y, x, batch_size, probs_size, ndim, seq_len, y_stride_b, y_stride_p, x_stride_b, x_stride_p, y_stride_0, y_stride_1, x_stride_0, x_stride_1);
}
#endif // __LOGSOFTMAX_KERNEL_CUH__
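logSoftmaxKernel assigns one thread block per row: each thread strides through the row with step BLOCK_SIZE, and cub::BlockReduce combines the per-thread partials for the max and for the exp-sum. A small Python model of those three phases (illustrative only; the block reduction is simulated sequentially):

```python
import math

def block_log_softmax(row, block_size=4):
    # Phase 1: "thread" t scans indices t, t+block_size, ... for its partial max,
    # then a block-wide reduce produces the shared max (BlockReduce with cub::Max).
    partial_max = [max((row[i] for i in range(t, len(row), block_size)),
                       default=-math.inf) for t in range(block_size)]
    shared_max = max(partial_max)
    # Phase 2: per-thread partial sums of exp(x - max), then a block-wide Sum.
    partial_sum = [sum(math.exp(row[i] - shared_max)
                       for i in range(t, len(row), block_size))
                   for t in range(block_size)]
    log_sum = math.log(sum(partial_sum))
    # Phase 3: every thread writes x - max - log(sum) for its strided indices.
    return [v - shared_max - log_sum for v in row]

print(block_log_softmax([0.5, 1.5, -0.5, 2.0, 0.0]))
```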
#ifndef __LOGSOFTMAX_INFO_H__
#define __LOGSOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::logsoftmax {
class LogSoftmaxInfo {
LogSoftmaxInfo() = default;
public:
infiniDtype_t x_dtype;
infiniDtype_t y_dtype;
size_t batch_size;
size_t probs_size;
// Original tensor dimensions for 3D support
size_t ndim;
size_t seq_len; // Only used for 3D tensors
// Flattened strides for row-wise iteration (used directly in the 2D case)
ptrdiff_t y_stride_b;
ptrdiff_t y_stride_p;
ptrdiff_t x_stride_b;
ptrdiff_t x_stride_p;
// Original 3D strides for correct memory access
ptrdiff_t y_stride_0, y_stride_1, y_stride_2;
ptrdiff_t x_stride_0, x_stride_1, x_stride_2;
static utils::Result<LogSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
auto x_dtype = x_desc->dtype();
auto y_dtype = y_desc->dtype();
CHECK_DTYPE(x_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
// Check the output data type; the output may keep the input dtype or use any other supported dtype (e.g. widened to fp32)
CHECK_DTYPE(y_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
auto x_shape = x_desc->shape();
auto y_shape = y_desc->shape();
CHECK_SAME_SHAPE(x_shape, y_shape);
auto ndim = x_desc->ndim();
if (ndim < 2 || ndim > 3) {
CHECK_STATUS(INFINI_STATUS_BAD_TENSOR_SHAPE);
}
size_t batch_size, probs_size, seq_len = 0;
if (ndim == 2) {
batch_size = x_shape[0];
probs_size = x_shape[1];
} else { // ndim == 3
batch_size = x_shape[0] * x_shape[1];
probs_size = x_shape[2];
seq_len = x_shape[1];
}
// Store original strides for all dimensions
ptrdiff_t y_stride_0 = 0, y_stride_1 = 0, y_stride_2 = 0;
ptrdiff_t x_stride_0 = 0, x_stride_1 = 0, x_stride_2 = 0;
if (ndim == 2) {
y_stride_0 = y_desc->stride(0); // First dimension
y_stride_1 = y_desc->stride(1); // Second dimension
x_stride_0 = x_desc->stride(0);
x_stride_1 = x_desc->stride(1);
} else if (ndim == 3) {
y_stride_0 = y_desc->stride(0); // First dimension (batch)
y_stride_1 = y_desc->stride(1); // Second dimension (seq)
y_stride_2 = y_desc->stride(2); // Third dimension (prob)
x_stride_0 = x_desc->stride(0);
x_stride_1 = x_desc->stride(1);
x_stride_2 = x_desc->stride(2);
}
ptrdiff_t y_stride_b, y_stride_p, x_stride_b, x_stride_p;
if (ndim == 2) {
y_stride_b = y_desc->stride(0);
y_stride_p = y_desc->stride(1);
x_stride_b = x_desc->stride(0);
x_stride_p = x_desc->stride(1);
} else { // ndim == 3
// For 3D tensors, flatten the first two dimensions: the kernels iterate over
// batch_size = shape[0] * shape[1] rows of probs_size elements each,
// so the flattened batch stride is the stride between consecutive sequences
y_stride_b = y_desc->stride(1); // stride between consecutive sequences
y_stride_p = y_desc->stride(2); // stride within the probability dimension
x_stride_b = x_desc->stride(1);
x_stride_p = x_desc->stride(2);
}
return utils::Result<LogSoftmaxInfo>(LogSoftmaxInfo{
x_dtype,
y_dtype,
batch_size,
probs_size,
ndim,
seq_len,
y_stride_b,
y_stride_p,
x_stride_b,
x_stride_p,
y_stride_0,
y_stride_1,
y_stride_2,
x_stride_0,
x_stride_1,
x_stride_2});
}
};
} // namespace op::logsoftmax
#endif // __LOGSOFTMAX_INFO_H__
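To see how the flattened batch index is mapped back to per-dimension offsets (the ndim == 3 branch in both kernels), here is a small NumPy check; the shapes are arbitrary stand-ins for the (batch, seq, probs) layout described above:

```python
import numpy as np

B, S, P = 2, 3, 4  # batch, seq_len, probs_size
x = np.arange(B * S * P, dtype=np.float32).reshape(B, S, P)
s0, s1, s2 = (s // x.itemsize for s in x.strides)  # element strides, like stride(0..2)

flat = x.reshape(B * S, P)
for batch in range(B * S):
    batch_idx, seq_idx = batch // S, batch % S  # same decomposition as the kernels
    offset = batch_idx * s0 + seq_idx * s1
    row = x.ravel()[offset : offset + P * s2 : s2]  # walk the prob dim with stride(2)
    assert np.array_equal(row, flat[batch])
print("3D offsets match flattened iteration")
```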
#ifndef __LOGSOFTMAX_H__
#define __LOGSOFTMAX_H__
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::logsoftmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
LogSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
LogSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
void *stream) const; \
}; \
}
#endif // __LOGSOFTMAX_H__
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "logsoftmax_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../cuda/kernel.cuh"
namespace op::logsoftmax::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = LogSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t x_dtype, infiniDtype_t y_dtype,
size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
ptrdiff_t x_stride_0, ptrdiff_t x_stride_1,
cudaStream_t stream) {
dim3 grid(uint32_t(batch_size), 1, 1);
// Handle mixed precision cases
if (x_dtype == INFINI_DTYPE_F16 && y_dtype == INFINI_DTYPE_F32) {
logSoftmax<BLOCK_SIZE, float, half, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const half *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_F16) {
logSoftmax<BLOCK_SIZE, half, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((half *)y, (const float *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_BF16 && y_dtype == INFINI_DTYPE_F32) {
logSoftmax<BLOCK_SIZE, float, __nv_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const __nv_bfloat16 *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_BF16) {
logSoftmax<BLOCK_SIZE, __nv_bfloat16, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__nv_bfloat16 *)y, (const float *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_F16 && y_dtype == INFINI_DTYPE_F16) {
logSoftmax<BLOCK_SIZE, half, half, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((half *)y, (const half *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_BF16 && y_dtype == INFINI_DTYPE_BF16) {
logSoftmax<BLOCK_SIZE, __nv_bfloat16, __nv_bfloat16, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else if (x_dtype == INFINI_DTYPE_F32 && y_dtype == INFINI_DTYPE_F32) {
logSoftmax<BLOCK_SIZE, float, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
batch_size, probs_size, ndim, seq_len,
y_stride_b, y_stride_p,
x_stride_b, x_stride_p,
y_stride_0, y_stride_1,
x_stride_0, x_stride_1);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::logsoftmax::nvidia
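For quick reference, the (x_dtype, y_dtype) pairs launchKernel dispatches, sketched in Python (names illustrative). Unlike the CPU path, F16→BF16 and BF16→F16 are not handled here and fall through to INFINI_STATUS_BAD_TENSOR_DTYPE:

```python
# (x_dtype, y_dtype) combinations dispatched by the CUDA launchKernel above;
# the compute type is always float32.
CUDA_SUPPORTED = {
    ("F16", "F16"), ("F16", "F32"),
    ("BF16", "BF16"), ("BF16", "F32"),
    ("F32", "F16"), ("F32", "BF16"), ("F32", "F32"),
}

def cuda_dispatch_ok(x_dtype: str, y_dtype: str) -> bool:
    # Mirrors the if/else chain; the CPU path additionally accepts
    # ("F16", "BF16") and ("BF16", "F16").
    return (x_dtype, y_dtype) in CUDA_SUPPORTED
```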
#ifndef __LOGSOFTMAX_NVIDIA_H__
#define __LOGSOFTMAX_NVIDIA_H__
#include "../logsoftmax.h"
DESCRIPTOR(nvidia)
#endif
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/logsoftmax.h"
#ifdef ENABLE_CPU_API
#include "cpu/logsoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/logsoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
// #include "metax/logsoftmax_metax.h"
#endif
#ifdef ENABLE_ASCEND_API
// #include "ascend/logsoftmax_ascend.h"
#endif
__C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
infiniopHandle_t handle,
infiniopLogSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::logsoftmax::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::logsoftmax::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
x_desc);
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ASCEND_API
// CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::logsoftmax::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ASCEND_API
// GET(INFINI_DEVICE_ASCEND, ascend)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopLogSoftmax(
infiniopLogSoftmaxDescriptor_t desc,
void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::logsoftmax::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, y, x, stream);
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ASCEND_API
// CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::logsoftmax::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DESTROY(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ASCEND_API
// DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -162,6 +162,38 @@ def clip_(lib):
]
@OpRegister.operator
def logsoftmax_(lib):
lib.infiniopCreateLogSoftmaxDescriptor.restype = c_int32
lib.infiniopCreateLogSoftmaxDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetLogSoftmaxWorkspaceSize.restype = c_int32
lib.infiniopGetLogSoftmaxWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopLogSoftmax.restype = c_int32
lib.infiniopLogSoftmax.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyLogSoftmaxDescriptor.restype = c_int32
lib.infiniopDestroyLogSoftmaxDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
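With these bindings registered, the call sequence mirrors the C API header: create the descriptor, query the workspace size, run the op, destroy the descriptor. A minimal sketch, assuming the libinfiniop test helpers imported in the test file below:

```python
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestWorkspace,
    check_error,
    infiniopOperatorDescriptor_t,
)

def run_logsoftmax(handle, x, y):
    # x, y: TestTensor input/output with matching shapes; handle: device handle.
    desc = infiniopOperatorDescriptor_t()
    check_error(LIBINFINIOP.infiniopCreateLogSoftmaxDescriptor(
        handle, ctypes.byref(desc), y.descriptor, x.descriptor))
    size = c_uint64(0)
    check_error(LIBINFINIOP.infiniopGetLogSoftmaxWorkspaceSize(desc, ctypes.byref(size)))
    workspace = TestWorkspace(size.value, x.device)
    check_error(LIBINFINIOP.infiniopLogSoftmax(
        desc, workspace.data(), size.value, y.data(), x.data(), None))
    check_error(LIBINFINIOP.infiniopDestroyLogSoftmaxDescriptor(desc))
```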
@OpRegister.operator
def conv_(lib):
pass
...
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# shape, x_stride, y_stride
((3, 3), None, None),
((32, 512), None, None),
((32, 512), (1024, 1), (1024, 1)),
((32, 5, 5), None, None),
((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None),
((28, 15, 15), None, None),
((1, 1000), None, None),
((16, 50257), None, None),
((4, 8, 256), None, None),
((2, 16, 1024), None, None),
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 1e-2},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Mixed precision test cases - support y_dtype == x_dtype or y_dtype == F32
_MIXED_PRECISION_CASES = [
(InfiniDtype.F16, InfiniDtype.F32),
(InfiniDtype.BF16, InfiniDtype.F32),
(InfiniDtype.F16, InfiniDtype.F16),
(InfiniDtype.BF16, InfiniDtype.BF16),
(InfiniDtype.F32, InfiniDtype.F32),
]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE = [
Inplace.INPLACE_X,
Inplace.OUT_OF_PLACE,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def logsoftmax(x):
"""PyTorch reference implementation of log_softmax"""
return torch.nn.functional.log_softmax(x.to(torch.float32), dim=-1)
def test(
handle,
device,
shape,
x_stride=None,
y_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing LogSoftmax on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
)
x = TestTensor(shape, x_stride, dtype, device)
ans = logsoftmax(x.actual_tensor())
# Convert answer to match input dtype for default behavior
if dtype == InfiniDtype.F16:
ans = ans.to(torch.float16)
elif dtype == InfiniDtype.BF16:
ans = ans.to(torch.bfloat16)
elif dtype == InfiniDtype.F32:
ans = ans.to(torch.float32)
if inplace == Inplace.INPLACE_X:
y = x
else:
y = TestTensor(shape, y_stride, dtype, device) # Default: same dtype as input
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
status = LIBINFINIOP.infiniopCreateLogSoftmaxDescriptor(
handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
)
check_error(status)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x.destroy_desc()
y.destroy_desc()
workspace_size = c_uint64(0)
status = LIBINFINIOP.infiniopGetLogSoftmaxWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
check_error(status)
workspace = TestWorkspace(workspace_size.value, x.device)
def lib_logsoftmax():
check_error(
LIBINFINIOP.infiniopLogSoftmax(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
x.data(),
None,
)
)
lib_logsoftmax()
if sync is not None:
sync()
# Select the tolerance based on the input dtype
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Always print debug info for failed cases
actual = y.actual_tensor()
max_diff = torch.max(torch.abs(actual - ans))
is_close = torch.allclose(actual, ans, atol=atol, rtol=rtol)
if DEBUG or not is_close:
print(f"\n=== Debug Info ===")
print(f"Shape: {shape}, Stride: {x_stride}, Dtype: {dtype}")
print(f"Input tensor: {x.torch_tensor()}")
print(f"Expected output: {ans}")
print(f"Actual output: {actual}")
print(f"Max diff: {max_diff}")
print(f"Tolerance: atol={atol}, rtol={rtol}")
print(f"Is close: {is_close}")
print(f"First few values - Actual: {actual.flatten()[:5]}")
print(f"First few values - Expected: {ans.flatten()[:5]}")
if DEBUG:
debug(actual, ans, atol=atol, rtol=rtol)
assert is_close
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: logsoftmax(x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_logsoftmax(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyLogSoftmaxDescriptor(descriptor))
def test_mixed_precision(
handle,
device,
shape,
x_stride=None,
y_stride=None,
inplace=Inplace.OUT_OF_PLACE,
x_dtype=InfiniDtype.F16,
y_dtype=InfiniDtype.F32,
sync=None,
):
print(
f"Testing LogSoftmax (Mixed) on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} x_dtype:{InfiniDtypeNames[x_dtype]} y_dtype:{InfiniDtypeNames[y_dtype]} inplace:{inplace}"
)
x = TestTensor(shape, x_stride, x_dtype, device)
ans = logsoftmax(x.actual_tensor())
# Convert answer to target dtype for comparison
if y_dtype == InfiniDtype.F16:
ans = ans.to(torch.float16)
elif y_dtype == InfiniDtype.BF16:
ans = ans.to(torch.bfloat16)
elif y_dtype == InfiniDtype.F32:
ans = ans.to(torch.float32)
if inplace == Inplace.INPLACE_X:
# For inplace operations, input and output must have the same dtype
if x_dtype != y_dtype:
print(
f"Skipping inplace test: x_dtype ({InfiniDtypeNames[x_dtype]}) != y_dtype ({InfiniDtypeNames[y_dtype]})"
)
return
y = x
else:
y = TestTensor(shape, y_stride, y_dtype, device)
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateLogSoftmaxDescriptor(
handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x.destroy_desc()
y.destroy_desc()
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetLogSoftmaxWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, x.device)
def lib_logsoftmax():
check_error(
LIBINFINIOP.infiniopLogSoftmax(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
x.data(),
None,
)
)
lib_logsoftmax()
if sync is not None:
sync()
# Use tolerance based on output dtype for mixed precision cases
atol, rtol = get_tolerance(_TOLERANCE_MAP, y_dtype)
# Ensure both tensors have the same dtype for comparison
y_tensor = y.actual_tensor()
if y_tensor.dtype != ans.dtype:
y_tensor = y_tensor.to(ans.dtype)
if DEBUG:
debug(y_tensor, ans, atol=atol, rtol=rtol)
assert torch.allclose(y_tensor, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: logsoftmax(x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_logsoftmax(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyLogSoftmaxDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
# Test standard cases (output dtype matches input dtype)
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
# Test mixed precision cases
from libinfiniop import create_handle, destroy_handle, get_sync_func
handle = create_handle()
sync = get_sync_func(device)
try:
for x_dtype, y_dtype in _MIXED_PRECISION_CASES:
for shape, x_stride, y_stride, inplace in _TEST_CASES[:5]:  # test a subset for mixed precision
test_mixed_precision(
handle,
device,
shape,
x_stride,
y_stride,
inplace,
x_dtype,
y_dtype,
sync,
)
finally:
destroy_handle(handle)
print("\033[92mTest passed!\033[0m")