Commit db7e4076 authored by xgqdut2016's avatar xgqdut2016 Committed by wooway777
Browse files

issue/1032n: support strided last dim in cuda swiglu

parent 362f0187
......@@ -29,17 +29,17 @@ __device__ void SwiGLUCudaKernel(
const T *b,
int length,
size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1) {
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
int ind_c = 0;
int ind_a = 0;
int ind_b = 0;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < length) {
ind_c += tid % (int)hidden_dim;
ind_a += tid % (int)hidden_dim;
ind_b += tid % (int)hidden_dim;
ind_c += tid % (int)hidden_dim * (int)c_strides_2;
ind_a += tid % (int)hidden_dim * (int)a_strides_2;
ind_b += tid % (int)hidden_dim * (int)b_strides_2;
tid = tid / (int)hidden_dim;
ind_c += (tid % (int)seq_len) * (int)c_strides_1;
ind_a += (tid % (int)seq_len) * (int)a_strides_1;
......@@ -51,6 +51,7 @@ __device__ void SwiGLUCudaKernel(
T gate = b[ind_b];
T up = a[ind_a];
if constexpr (std::is_same_v<T, half2>) {
c[ind_c] = __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
......
......@@ -14,9 +14,9 @@ public:
infiniDtype_t dtype;
size_t length;
size_t batch, seq_len, hidden_dim;
ptrdiff_t c_strides_0, c_strides_1;
ptrdiff_t a_strides_0, a_strides_1;
ptrdiff_t b_strides_0, b_strides_1;
ptrdiff_t c_strides_0, c_strides_1, c_strides_2;
ptrdiff_t a_strides_0, a_strides_1, a_strides_2;
ptrdiff_t b_strides_0, b_strides_1, b_strides_2;
static utils::Result<SwiGLUCudaInfo> createSwiGLUCudaInfo(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
auto dtype = c_desc->dtype();
......@@ -37,10 +37,13 @@ public:
ptrdiff_t c_strides_0 = (ndim == 3 ? c_desc->strides()[0] : 0);
ptrdiff_t c_strides_1 = (ndim == 3 ? c_desc->strides()[1] : c_desc->strides()[0]);
ptrdiff_t c_strides_2 = (ndim == 3 ? c_desc->strides()[2] : c_desc->strides()[1]);
ptrdiff_t a_strides_0 = (ndim == 3 ? a_desc->strides()[0] : 0);
ptrdiff_t a_strides_1 = (ndim == 3 ? a_desc->strides()[1] : a_desc->strides()[0]);
ptrdiff_t a_strides_2 = (ndim == 3 ? a_desc->strides()[2] : a_desc->strides()[1]);
ptrdiff_t b_strides_0 = (ndim == 3 ? b_desc->strides()[0] : 0);
ptrdiff_t b_strides_1 = (ndim == 3 ? b_desc->strides()[1] : b_desc->strides()[0]);
ptrdiff_t b_strides_2 = (ndim == 3 ? b_desc->strides()[2] : b_desc->strides()[1]);
return utils::Result<SwiGLUCudaInfo>(SwiGLUCudaInfo{
dtype,
......@@ -50,10 +53,13 @@ public:
hidden_dim,
c_strides_0,
c_strides_1,
c_strides_2,
a_strides_0,
a_strides_1,
a_strides_2,
b_strides_0,
b_strides_1});
b_strides_1,
b_strides_2});
}
};
......
......@@ -10,13 +10,13 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda(
const T *b,
int length,
size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1) {
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1,
a_strides_0, a_strides_1,
b_strides_0, b_strides_1);
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
}
namespace op::swiglu_cuda::nvidia {
......@@ -55,22 +55,25 @@ infiniStatus_t calculate_swiglu_cuda(
void *workspace) {
int length = (int)info.length;
int batch = (int)info.batch;
int seq_len = (int)info.seq_len;
int hidden_dim = (int)info.hidden_dim;
int c_strides_0 = (int)info.c_strides_0;
int c_strides_1 = (int)info.c_strides_1;
int a_strides_0 = (int)info.a_strides_0;
int a_strides_1 = (int)info.a_strides_1;
int b_strides_0 = (int)info.b_strides_0;
int b_strides_1 = (int)info.b_strides_1;
size_t batch = info.batch;
size_t seq_len = info.seq_len;
size_t hidden_dim = info.hidden_dim;
ptrdiff_t c_strides_0 = info.c_strides_0;
ptrdiff_t c_strides_1 = info.c_strides_1;
ptrdiff_t c_strides_2 = info.c_strides_2;
ptrdiff_t a_strides_0 = info.a_strides_0;
ptrdiff_t a_strides_1 = info.a_strides_1;
ptrdiff_t a_strides_2 = info.a_strides_2;
ptrdiff_t b_strides_0 = info.b_strides_0;
ptrdiff_t b_strides_1 = info.b_strides_1;
ptrdiff_t b_strides_2 = info.b_strides_2;
int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
SwiGLUCuda<T, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1,
a_strides_0, a_strides_1,
b_strides_0, b_strides_1);
c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1, b_strides_2);
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment