Commit ae49716d authored by wooway777's avatar wooway777
Browse files

issue/1032 - follow nv changes on metax (swiglu)

parent db7e4076
...@@ -11,13 +11,13 @@ INFINIOP_METAX_KERNEL SwiGLUCuda( ...@@ -11,13 +11,13 @@ INFINIOP_METAX_KERNEL SwiGLUCuda(
const T *b, const T *b,
int length, int length,
size_t batch, size_t seq_len, size_t hidden_dim, size_t batch, size_t seq_len, size_t hidden_dim,
ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
ptrdiff_t b_strides_0, ptrdiff_t b_strides_1) { ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim, SwiGLUCudaKernel<T, BLOCK_SIZE>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1); b_strides_0, b_strides_1, b_strides_2);
} }
namespace op::swiglu_cuda::metax { namespace op::swiglu_cuda::metax {
...@@ -56,22 +56,25 @@ infiniStatus_t calculate_swiglu_cuda( ...@@ -56,22 +56,25 @@ infiniStatus_t calculate_swiglu_cuda(
void *workspace) { void *workspace) {
int length = (int)info.length; int length = (int)info.length;
int batch = (int)info.batch; size_t batch = info.batch;
int seq_len = (int)info.seq_len; size_t seq_len = info.seq_len;
int hidden_dim = (int)info.hidden_dim; size_t hidden_dim = info.hidden_dim;
int c_strides_0 = (int)info.c_strides_0; ptrdiff_t c_strides_0 = info.c_strides_0;
int c_strides_1 = (int)info.c_strides_1; ptrdiff_t c_strides_1 = info.c_strides_1;
int a_strides_0 = (int)info.a_strides_0; ptrdiff_t c_strides_2 = info.c_strides_2;
int a_strides_1 = (int)info.a_strides_1; ptrdiff_t a_strides_0 = info.a_strides_0;
int b_strides_0 = (int)info.b_strides_0; ptrdiff_t a_strides_1 = info.a_strides_1;
int b_strides_1 = (int)info.b_strides_1; ptrdiff_t a_strides_2 = info.a_strides_2;
ptrdiff_t b_strides_0 = info.b_strides_0;
ptrdiff_t b_strides_1 = info.b_strides_1;
ptrdiff_t b_strides_2 = info.b_strides_2;
int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; int num_blocks = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
SwiGLUCuda<T, BLOCK_SIZE> SwiGLUCuda<T, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim, <<<num_blocks, BLOCK_SIZE, 0, stream>>>(c, a, b, length, batch, seq_len, hidden_dim,
c_strides_0, c_strides_1, c_strides_0, c_strides_1, c_strides_2,
a_strides_0, a_strides_1, a_strides_0, a_strides_1, a_strides_2,
b_strides_0, b_strides_1); b_strides_0, b_strides_1, b_strides_2);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment