Commit 063beaec authored by PanZezhong's avatar PanZezhong
Browse files

issue/4 更新接口

parent 6b717b30
......@@ -16,8 +16,9 @@ Descriptor::~Descriptor() {
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) {
auto info = CausalSoftmaxInfo::create(y_desc);
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
......@@ -26,14 +27,24 @@ infiniStatus_t Descriptor::create(
}
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *data, infiniDtype_t dtype, size_t batch_size, size_t seq_len, size_t total_seq_len, ptrdiff_t stride_b, ptrdiff_t stride_i, cudaStream_t stream) {
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
size_t batch_size, size_t seq_len, size_t total_seq_len,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
cudaStream_t stream) {
dim3 grid(uint32_t(seq_len), uint32_t(batch_size), 1);
if (dtype == INFINI_DTYPE_F16) {
causalSoftmax<BLOCK_SIZE, half, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((half *)data, batch_size, seq_len, total_seq_len, stride_b, stride_i);
<<<grid, BLOCK_SIZE, 0, stream>>>((half *)y, (const half *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else if (dtype == INFINI_DTYPE_F32) {
causalSoftmax<BLOCK_SIZE, float, float>
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)data, batch_size, seq_len, total_seq_len, stride_b, stride_i);
<<<grid, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
batch_size, seq_len, total_seq_len,
y_stride_b, y_stride_i,
x_stride_b, x_stride_i);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -41,15 +52,18 @@ infiniStatus_t launchKernel(void *data, infiniDtype_t dtype, size_t batch_size,
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *data,
void *y,
const void *x,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
data, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.stride_b, _info.stride_i, stream));
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
data, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, _info.stride_b, _info.stride_i, stream));
y, x, _info.dtype, _info.batch_size, _info.seq_len, _info.total_seq_len,
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
......@@ -4,14 +4,20 @@
#include "../../../devices/cuda/cuda_common.cuh"
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(Tdata *data_, size_t batch, size_t height, size_t width, ptrdiff_t stride_b, ptrdiff_t stride_h) {
Tdata *data = data_ // threadIdx.x for col_id
+ blockIdx.y * stride_b // gridDim.y for batch_id
+ blockIdx.x * stride_h; // gridDim.x for row_id
INFINIOP_CUDA_KERNEL causalSoftmax(
Tdata *y_, const Tdata *x_,
size_t batch, size_t height, size_t width,
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_h) {
Tdata *y = y_ // threadIdx.x for col_id
+ blockIdx.y * y_stride_b // gridDim.y for batch_id
+ blockIdx.x * y_stride_h; // gridDim.x for row_id
const Tdata *x = x_ + blockIdx.y * x_stride_b + blockIdx.x * x_stride_h;
// [Reduce] Find max value in each row and store in shared memory
__shared__ Tdata max_;
Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(data, width);
Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width);
if (threadIdx.x == 0) {
max_ = max_0;
}
......@@ -25,16 +31,16 @@ INFINIOP_CUDA_KERNEL causalSoftmax(Tdata *data_, size_t batch, size_t height, si
// 2 | * * * ... * * * |
// height: 3 col_id->
if (width + blockIdx.x >= threadIdx.x + height) {
data[col] = exp(data[col] - max_);
y[col] = exp(x[col] - max_);
} else {
data[col] = Tdata(0);
y[col] = Tdata(0);
}
}
__syncthreads();
// [Reduce] Find the sum of each updated row and store in shared memory
__shared__ Tcompute sum_;
Tcompute sum_0 = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(data, width);
Tcompute sum_0 = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
if (threadIdx.x == 0) {
sum_ = sum_0;
}
......@@ -42,7 +48,7 @@ INFINIOP_CUDA_KERNEL causalSoftmax(Tdata *data_, size_t batch, size_t height, si
// [Elementwise] Divide each element by the sum and store in shared memory
for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
data[col] /= Tdata(sum_);
y[col] /= Tdata(sum_);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment