Commit db7e4076 authored by xgqdut2016's avatar xgqdut2016 Committed by wooway777
Browse files

issue/1032n: support strided last dim in cuda swiglu

parent 362f0187
...@@ -29,17 +29,17 @@ __device__ void SwiGLUCudaKernel( ...@@ -29,17 +29,17 @@ __device__ void SwiGLUCudaKernel(
const T *b, const T *b,
int length, int length,
size_t batch, size_t seq_len, size_t hidden_dim, size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1) { ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
int ind_c = 0; int ind_c = 0;
int ind_a = 0; int ind_a = 0;
int ind_b = 0; int ind_b = 0;
int tid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < length) { if (tid < length) {
ind_c += tid % (int)hidden_dim; ind_c += tid % (int)hidden_dim * (int)c_strides_2;
ind_a += tid % (int)hidden_dim; ind_a += tid % (int)hidden_dim * (int)a_strides_2;
ind_b += tid % (int)hidden_dim; ind_b += tid % (int)hidden_dim * (int)b_strides_2;
tid = tid / (int)hidden_dim; tid = tid / (int)hidden_dim;
ind_c += (tid % (int)seq_len) * (int)c_strides_1; ind_c += (tid % (int)seq_len) * (int)c_strides_1;
ind_a += (tid % (int)seq_len) * (int)a_strides_1; ind_a += (tid % (int)seq_len) * (int)a_strides_1;
...@@ -51,6 +51,7 @@ __device__ void SwiGLUCudaKernel( ...@@ -51,6 +51,7 @@ __device__ void SwiGLUCudaKernel(
T gate = b[ind_b]; T gate = b[ind_b];
T up = a[ind_a]; T up = a[ind_a];
if constexpr (std::is_same_v<T, half2>) { if constexpr (std::is_same_v<T, half2>) {
c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up); c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
......
...@@ -14,9 +14,9 @@ public: ...@@ -14,9 +14,9 @@ public:
infiniDtype_t dtype; infiniDtype_t dtype;
size_t length; size_t length;
size_t batch, seq_len, hidden_dim; size_t batch, seq_len, hidden_dim;
ptrdiff_t c_strides_0, c_strides_1; ptrdiff_t c_strides_0, c_strides_1, c_strides_2;
ptrdiff_t a_strides_0, a_strides_1; ptrdiff_t a_strides_0, a_strides_1, a_strides_2;
ptrdiff_t b_strides_0, b_strides_1; ptrdiff_t b_strides_0, b_strides_1, b_strides_2;
static utils::Result<SwiGLUCudaInfo> createSwiGLUCudaInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { static utils::Result<SwiGLUCudaInfo> createSwiGLUCudaInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
...@@ -37,10 +37,13 @@ public: ...@@ -37,10 +37,13 @@ public:
ptrdiff_t c_strides_0 = (ndim == 3 ? c_desc->strides()[0] : 0); ptrdiff_t c_strides_0 = (ndim == 3 ? c_desc->strides()[0] : 0);
ptrdiff_t c_strides_1 = (ndim == 3 ? c_desc->strides()[1] : c_desc->strides()[0]); ptrdiff_t c_strides_1 = (ndim == 3 ? c_desc->strides()[1] : c_desc->strides()[0]);
ptrdiff_t c_strides_2 = (ndim == 3 ? c_desc->strides()[2] : c_desc->strides()[1]);
ptrdiff_t a_strides_0 = (ndim == 3 ? a_desc->strides()[0] : 0); ptrdiff_t a_strides_0 = (ndim == 3 ? a_desc->strides()[0] : 0);
ptrdiff_t a_strides_1 = (ndim == 3 ? a_desc->strides()[1] : a_desc->strides()[0]); ptrdiff_t a_strides_1 = (ndim == 3 ? a_desc->strides()[1] : a_desc->strides()[0]);
ptrdiff_t a_strides_2 = (ndim == 3 ? a_desc->strides()[2] : a_desc->strides()[1]);
ptrdiff_t b_strides_0 = (ndim == 3 ? b_desc->strides()[0] : 0); ptrdiff_t b_strides_0 = (ndim == 3 ? b_desc->strides()[0] : 0);
ptrdiff_t b_strides_1 = (ndim == 3 ? b_desc->strides()[1] : b_desc->strides()[0]); ptrdiff_t b_strides_1 = (ndim == 3 ? b_desc->strides()[1] : b_desc->strides()[0]);
ptrdiff_t b_strides_2 = (ndim == 3 ? b_desc->strides()[2] : b_desc->strides()[1]);
return utils::Result<SwiGLUCudaInfo>(SwiGLUCudaInfo{ return utils::Result<SwiGLUCudaInfo>(SwiGLUCudaInfo{
dtype, dtype,
...@@ -50,10 +53,13 @@ public: ...@@ -50,10 +53,13 @@ public:
hidden_dim, hidden_dim,
c_strides_0, c_strides_0,
c_strides_1, c_strides_1,
c_strides_2,
a_strides_0, a_strides_0,
a_strides_1, a_strides_1,
a_strides_2,
b_strides_0, b_strides_0,
b_strides_1}); b_strides_1,
b_strides_2});
} }
}; };
......
...@@ -10,13 +10,13 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda( ...@@ -10,13 +10,13 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda(
const T *b, const T *b,
int length, int length,
size_t batch, size_t seq_len, size_t hidden_dim, size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1) { ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim, SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1); b_strides_0, b_strides_1, b_strides_2);
} }
namespace op::swiglu_cuda::nvidia { namespace op::swiglu_cuda::nvidia {
...@@ -55,22 +55,25 @@ infiniStatus_t calculate_swiglu_cuda( ...@@ -55,22 +55,25 @@ infiniStatus_t calculate_swiglu_cuda(
void *workspace) { void *workspace) {
int length = (int)info.length; int length = (int)info.length;
int batch = (int)info.batch; size_t batch = info.batch;
int seq_len = (int)info.seq_len; size_t seq_len = info.seq_len;
int hidden_dim = (int)info.hidden_dim; size_t hidden_dim = info.hidden_dim;
int c_strides_0 = (int)info.c_strides_0; ptrdiff_t c_strides_0 = info.c_strides_0;
int c_strides_1 = (int)info.c_strides_1; ptrdiff_t c_strides_1 = info.c_strides_1;
int a_strides_0 = (int)info.a_strides_0; ptrdiff_t c_strides_2 = info.c_strides_2;
int a_strides_1 = (int)info.a_strides_1; ptrdiff_t a_strides_0 = info.a_strides_0;
int b_strides_0 = (int)info.b_strides_0; ptrdiff_t a_strides_1 = info.a_strides_1;
int b_strides_1 = (int)info.b_strides_1; ptrdiff_t a_strides_2 = info.a_strides_2;
ptrdiff_t b_strides_0 = info.b_strides_0;
ptrdiff_t b_strides_1 = info.b_strides_1;
ptrdiff_t b_strides_2 = info.b_strides_2;
int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
SwiGLUCuda<T, BLOCK_SIZE> SwiGLUCuda<T, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim, <<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1); b_strides_0, b_strides_1, b_strides_2);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment