Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
db7e4076
Commit
db7e4076
authored
Feb 28, 2026
by
xgqdut2016
Committed by
wooway777
Feb 28, 2026
Browse files
issue/1032n: support strided last dim in cuda swiglu
parent
362f0187
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
28 deletions
+38
-28
src/infiniop/ops/swiglu/cuda/kernel_cuda.cuh
src/infiniop/ops/swiglu/cuda/kernel_cuda.cuh
+7
-6
src/infiniop/ops/swiglu/info.h
src/infiniop/ops/swiglu/info.h
+10
-4
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
+21
-18
No files found.
src/infiniop/ops/swiglu/cuda/kernel_cuda.cuh
View file @
db7e4076
...
...
@@ -29,17 +29,17 @@ __device__ void SwiGLUCudaKernel(
const
T
*
b
,
int
length
,
size_t
batch
,
size_t
seq_len
,
size_t
hidden_dim
,
ptrdiff_t
c_strides_0
,
ptrdiff_t
c_strides_1
,
ptrdiff_t
a_strides_0
,
ptrdiff_t
a_strides_1
,
ptrdiff_t
b_strides_0
,
ptrdiff_t
b_strides_1
)
{
ptrdiff_t
c_strides_0
,
ptrdiff_t
c_strides_1
,
ptrdiff_t
c_strides_2
,
ptrdiff_t
a_strides_0
,
ptrdiff_t
a_strides_1
,
ptrdiff_t
a_strides_2
,
ptrdiff_t
b_strides_0
,
ptrdiff_t
b_strides_1
,
ptrdiff_t
b_strides_2
)
{
int
ind_c
=
0
;
int
ind_a
=
0
;
int
ind_b
=
0
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
<
length
)
{
ind_c
+=
tid
%
(
int
)
hidden_dim
;
ind_a
+=
tid
%
(
int
)
hidden_dim
;
ind_b
+=
tid
%
(
int
)
hidden_dim
;
ind_c
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
c_strides_2
;
ind_a
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
a_strides_2
;
ind_b
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
b_strides_2
;
tid
=
tid
/
(
int
)
hidden_dim
;
ind_c
+=
(
tid
%
(
int
)
seq_len
)
*
(
int
)
c_strides_1
;
ind_a
+=
(
tid
%
(
int
)
seq_len
)
*
(
int
)
a_strides_1
;
...
...
@@ -51,6 +51,7 @@ __device__ void SwiGLUCudaKernel(
T
gate
=
b
[
ind_b
];
T
up
=
a
[
ind_a
];
if
constexpr
(
std
::
is_same_v
<
T
,
half2
>
)
{
c
[
ind_c
]
=
__hmul2
(
__hmul2
(
gate
,
sigmoid
(
gate
)),
up
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
...
...
src/infiniop/ops/swiglu/info.h
View file @
db7e4076
...
...
@@ -14,9 +14,9 @@ public:
infiniDtype_t
dtype
;
size_t
length
;
size_t
batch
,
seq_len
,
hidden_dim
;
ptrdiff_t
c_strides_0
,
c_strides_1
;
ptrdiff_t
a_strides_0
,
a_strides_1
;
ptrdiff_t
b_strides_0
,
b_strides_1
;
ptrdiff_t
c_strides_0
,
c_strides_1
,
c_strides_2
;
ptrdiff_t
a_strides_0
,
a_strides_1
,
a_strides_2
;
ptrdiff_t
b_strides_0
,
b_strides_1
,
b_strides_2
;
static
utils
::
Result
<
SwiGLUCudaInfo
>
createSwiGLUCudaInfo
(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
auto
dtype
=
c_desc
->
dtype
();
...
...
@@ -37,10 +37,13 @@ public:
ptrdiff_t
c_strides_0
=
(
ndim
==
3
?
c_desc
->
strides
()[
0
]
:
0
);
ptrdiff_t
c_strides_1
=
(
ndim
==
3
?
c_desc
->
strides
()[
1
]
:
c_desc
->
strides
()[
0
]);
ptrdiff_t
c_strides_2
=
(
ndim
==
3
?
c_desc
->
strides
()[
2
]
:
c_desc
->
strides
()[
1
]);
ptrdiff_t
a_strides_0
=
(
ndim
==
3
?
a_desc
->
strides
()[
0
]
:
0
);
ptrdiff_t
a_strides_1
=
(
ndim
==
3
?
a_desc
->
strides
()[
1
]
:
a_desc
->
strides
()[
0
]);
ptrdiff_t
a_strides_2
=
(
ndim
==
3
?
a_desc
->
strides
()[
2
]
:
a_desc
->
strides
()[
1
]);
ptrdiff_t
b_strides_0
=
(
ndim
==
3
?
b_desc
->
strides
()[
0
]
:
0
);
ptrdiff_t
b_strides_1
=
(
ndim
==
3
?
b_desc
->
strides
()[
1
]
:
b_desc
->
strides
()[
0
]);
ptrdiff_t
b_strides_2
=
(
ndim
==
3
?
b_desc
->
strides
()[
2
]
:
b_desc
->
strides
()[
1
]);
return
utils
::
Result
<
SwiGLUCudaInfo
>
(
SwiGLUCudaInfo
{
dtype
,
...
...
@@ -50,10 +53,13 @@ public:
hidden_dim
,
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
});
b_strides_1
,
b_strides_2
});
}
};
...
...
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
View file @
db7e4076
...
...
@@ -10,13 +10,13 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda(
const
T
*
b
,
int
length
,
size_t
batch
,
size_t
seq_len
,
size_t
hidden_dim
,
ptrdiff_t
c_strides_0
,
ptrdiff_t
c_strides_1
,
ptrdiff_t
a_strides_0
,
ptrdiff_t
a_strides_1
,
ptrdiff_t
b_strides_0
,
ptrdiff_t
b_strides_1
)
{
ptrdiff_t
c_strides_0
,
ptrdiff_t
c_strides_1
,
ptrdiff_t
c_strides_2
,
ptrdiff_t
a_strides_0
,
ptrdiff_t
a_strides_1
,
ptrdiff_t
a_strides_2
,
ptrdiff_t
b_strides_0
,
ptrdiff_t
b_strides_1
,
ptrdiff_t
b_strides_2
)
{
SwiGLUCudaKernel
<
T
,
BLOCK_SIZE
>
(
c
,
a
,
b
,
length
,
batch
,
seq_len
,
hidden_dim
,
c_strides_0
,
c_strides_1
,
a_strides_0
,
a_strides_1
,
b_strides_0
,
b_strides_1
);
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
,
b_strides_2
);
}
namespace
op
::
swiglu_cuda
::
nvidia
{
...
...
@@ -55,22 +55,25 @@ infiniStatus_t calculate_swiglu_cuda(
void
*
workspace
)
{
int
length
=
(
int
)
info
.
length
;
int
batch
=
(
int
)
info
.
batch
;
int
seq_len
=
(
int
)
info
.
seq_len
;
int
hidden_dim
=
(
int
)
info
.
hidden_dim
;
int
c_strides_0
=
(
int
)
info
.
c_strides_0
;
int
c_strides_1
=
(
int
)
info
.
c_strides_1
;
int
a_strides_0
=
(
int
)
info
.
a_strides_0
;
int
a_strides_1
=
(
int
)
info
.
a_strides_1
;
int
b_strides_0
=
(
int
)
info
.
b_strides_0
;
int
b_strides_1
=
(
int
)
info
.
b_strides_1
;
size_t
batch
=
info
.
batch
;
size_t
seq_len
=
info
.
seq_len
;
size_t
hidden_dim
=
info
.
hidden_dim
;
ptrdiff_t
c_strides_0
=
info
.
c_strides_0
;
ptrdiff_t
c_strides_1
=
info
.
c_strides_1
;
ptrdiff_t
c_strides_2
=
info
.
c_strides_2
;
ptrdiff_t
a_strides_0
=
info
.
a_strides_0
;
ptrdiff_t
a_strides_1
=
info
.
a_strides_1
;
ptrdiff_t
a_strides_2
=
info
.
a_strides_2
;
ptrdiff_t
b_strides_0
=
info
.
b_strides_0
;
ptrdiff_t
b_strides_1
=
info
.
b_strides_1
;
ptrdiff_t
b_strides_2
=
info
.
b_strides_2
;
int
num_blocks
=
(
length
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
SwiGLUCuda
<
T
,
BLOCK_SIZE
>
<<<
num_blocks
,
BLOCK_SIZE
,
0
,
stream
>>>
(
c
,
a
,
b
,
length
,
batch
,
seq_len
,
hidden_dim
,
c_strides_0
,
c_strides_1
,
a_strides_0
,
a_strides_1
,
b_strides_0
,
b_strides_1
);
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
,
b_strides_2
);
return
INFINI_STATUS_SUCCESS
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment