Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
71cac971
Commit
71cac971
authored
Apr 23, 2026
by
sunchao_0511
Browse files
remove rope debug && swiglu optim
parent
a1937618
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
205 additions
and
57 deletions
+205
-57
src/infiniop/ops/rope/nvidia/rope_nvidia.cu
src/infiniop/ops/rope/nvidia/rope_nvidia.cu
+0
-15
src/infiniop/ops/swiglu/cuda/kernel.cuh
src/infiniop/ops/swiglu/cuda/kernel.cuh
+137
-39
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
+68
-3
No files found.
src/infiniop/ops/rope/nvidia/rope_nvidia.cu
View file @
71cac971
...
...
@@ -126,13 +126,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim
=
dim3
(
dimx
,
dimy
,
1
);
}
// printf("block_size = %d info.table_dim = %ld has_batch_dim: %d, is_gpt_j: %d pos_has_batch_dim: %d\n",
// block_size, info.table_dim, info.has_batch_dim, is_gpt_j, info.pos_has_batch_dim);
// [batch, seqlen, nhead, dhead, table_len, table_dim, y_stride_batch, y_stride_seqlen, y_stride_nhead, x_stride_batch, x_stride_seqlen,x_stride_nhead]
// printf("[%ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld]\n", info.batch,
// info.seqlen, info.nhead, info.dhead, info.table_len, info.table_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
if
(
is_gpt_j
)
{
ropeThreadPerItemKernel
<
true
><<<
grid_dim
,
nthreads
,
0
,
stream
>>>
(
y
,
x
,
pos_ids
,
sin_table
,
cos_table
,
info
.
table_dim
,
...
...
@@ -154,14 +147,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
info
.
has_batch_dim
,
info
.
y_stride_batch
,
info
.
y_stride_seqlen
,
info
.
y_stride_nhead
,
info
.
x_stride_batch
,
info
.
x_stride_seqlen
,
info
.
x_stride_nhead
);
// ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
// y, x, pos_ids, sin_table, cos_table, info.table_dim,
// pos_stride_batch,
// info.pos_has_batch_dim,
// info.has_batch_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
}
else
{
ropeThreadPerItemKernel
<
false
><<<
grid_dim
,
nthreads
,
0
,
stream
>>>
(
y
,
x
,
pos_ids
,
sin_table
,
cos_table
,
info
.
table_dim
,
...
...
src/infiniop/ops/swiglu/cuda/kernel.cuh
View file @
71cac971
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
namespace
op
::
swiglu
::
cuda
{
typedef
struct
SwiGLUOp
{
private:
template
<
typename
T
>
__device__
__forceinline__
T
sigmoid
(
const
T
&
x
)
const
{
if
constexpr
(
std
::
is_same_v
<
T
,
half2
>
)
{
return
h2rcp
(
__hadd2
(
make_half2
(
1
,
1
),
h2exp
(
__hneg2
(
x
))));
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
return
hrcp
(
__hadd
(
half
(
1.
f
),
__float2half
(
__expf
(
__half2float
(
__hneg
(
x
))))));
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda_bfloat162
>
)
{
float
x0
=
__bfloat162float
(
__low2bfloat16
(
x
));
float
x1
=
__bfloat162float
(
__high2bfloat16
(
x
));
float
sig0
=
__frcp_rn
(
__fadd_rn
(
1.0
f
,
__expf
(
-
x0
)));
float
sig1
=
__frcp_rn
(
__fadd_rn
(
1.0
f
,
__expf
(
-
x1
)));
return
__floats2bfloat162_rn
(
sig0
,
sig1
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda_bfloat16
>
)
{
float
xf
=
__bfloat162float
(
x
);
return
__float2bfloat16_rn
(
__frcp_rn
(
__fadd_rn
(
1.0
f
,
__expf
(
-
xf
))));
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
return
__frcp_rn
(
__fadd_rn
(
1
,
__expf
(
-
x
)));
}
else
{
return
1
/
(
1
+
std
::
exp
(
-
x
));
}
#ifndef __SWIGLU_CUDA_KERNEL_CUH__
#define __SWIGLU_CUDA_KERNEL_CUH__
/// Elementwise logistic sigmoid, 1 / (1 + exp(-x)), specialized per CUDA type.
/// Packed types (half2 / cuda_bfloat162) process both lanes; bf16 lanes are
/// widened to float for the exponential and rounded back.
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) {
    if constexpr (std::is_same_v<T, half2>) {
        // Both half lanes at once: rcp(1 + exp(-x)) in half2 intrinsics.
        const half2 ones = make_half2(1, 1);
        const half2 exp_neg = h2exp(__hneg2(x));
        return h2rcp(__hadd2(ones, exp_neg));
    } else if constexpr (std::is_same_v<T, half>) {
        // Scalar half: exponential evaluated in float, then narrowed back.
        const half exp_neg = __float2half(__expf(__half2float(__hneg(x))));
        return hrcp(__hadd(half(1.f), exp_neg));
    } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
        // Packed bf16 pair: unpack to float, compute each lane, repack.
        const float lo = __bfloat162float(__low2bfloat16(x));
        const float hi = __bfloat162float(__high2bfloat16(x));
        const float sig_lo = __frcp_rn(__fadd_rn(1.0f, __expf(-lo)));
        const float sig_hi = __frcp_rn(__fadd_rn(1.0f, __expf(-hi)));
        return __floats2bfloat162_rn(sig_lo, sig_hi);
    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
        // Scalar bf16: widen, compute in float, round to nearest bf16.
        const float xf = __bfloat162float(x);
        return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
    } else if constexpr (std::is_same_v<T, float>) {
        return __frcp_rn(__fadd_rn(1, __expf(-x)));
    } else {
        // Generic fallback (e.g. double) via the standard library.
        return 1 / (1 + std::exp(-x));
    }
}
template
<
typename
T
,
unsigned
int
BLOCK_SIZE
>
__device__
void
SwiGLUCudaKernel
(
T
*
c
,
const
T
*
a
,
const
T
*
b
,
int
length
,
size_t
batch
,
size_t
seq_len
,
size_t
hidden_dim
,
ptrdiff_t
c_strides_0
,
ptrdiff_t
c_strides_1
,
ptrdiff_t
c_strides_2
,
ptrdiff_t
a_strides_0
,
ptrdiff_t
a_strides_1
,
ptrdiff_t
a_strides_2
,
ptrdiff_t
b_strides_0
,
ptrdiff_t
b_strides_1
,
ptrdiff_t
b_strides_2
)
{
int
ind_c
=
0
;
int
ind_a
=
0
;
int
ind_b
=
0
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
<
length
)
{
ind_c
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
c_strides_2
;
ind_a
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
a_strides_2
;
ind_b
+=
tid
%
(
int
)
hidden_dim
*
(
int
)
b_strides_2
;
tid
=
tid
/
(
int
)
hidden_dim
;
ind_c
+=
(
tid
%
(
int
)
seq_len
)
*
(
int
)
c_strides_1
;
ind_a
+=
(
tid
%
(
int
)
seq_len
)
*
(
int
)
a_strides_1
;
ind_b
+=
(
tid
%
(
int
)
seq_len
)
*
(
int
)
b_strides_1
;
tid
=
tid
/
(
int
)
seq_len
;
ind_c
+=
(
tid
%
(
int
)
batch
)
*
(
int
)
c_strides_0
;
ind_a
+=
(
tid
%
(
int
)
batch
)
*
(
int
)
a_strides_0
;
ind_b
+=
(
tid
%
(
int
)
batch
)
*
(
int
)
b_strides_0
;
T
gate
=
b
[
ind_b
];
T
up
=
a
[
ind_a
];
public:
static
constexpr
size_t
num_inputs
=
2
;
template
<
typename
T
>
__device__
__forceinline__
T
operator
()(
const
T
&
up
,
const
T
&
gate
)
const
{
if
constexpr
(
std
::
is_same_v
<
T
,
half2
>
)
{
return
__hmul2
(
__hmul2
(
gate
,
sigmoid
(
gate
)),
up
);
c
[
ind_c
]
=
__hmul2
(
__hmul2
(
gate
,
sigmoid
(
gate
)),
up
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
return
__hmul
(
__hmul
(
gate
,
sigmoid
(
gate
)),
up
);
c
[
ind_c
]
=
__hmul
(
__hmul
(
gate
,
sigmoid
(
gate
)),
up
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda_bfloat162
>
)
{
cuda_bfloat162
sig
=
sigmoid
(
gate
);
float
gate0
=
__bfloat162float
(
__low2bfloat16
(
gate
));
...
...
@@ -44,20 +66,96 @@ public:
float
up1
=
__bfloat162float
(
__high2bfloat16
(
up
));
float
res0
=
__fmul_rn
(
__fmul_rn
(
gate0
,
sig0
),
up0
);
float
res1
=
__fmul_rn
(
__fmul_rn
(
gate1
,
sig1
),
up1
);
return
__floats2bfloat162_rn
(
res0
,
res1
);
c
[
ind_c
]
=
__floats2bfloat162_rn
(
res0
,
res1
);
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda_bfloat16
>
)
{
cuda_bfloat16
sig
=
sigmoid
(
gate
);
float
gatef
=
__bfloat162float
(
gate
);
float
sigf
=
__bfloat162float
(
sig
);
float
upf
=
__bfloat162float
(
up
);
return
__float2bfloat16_rn
(
__fmul_rn
(
__fmul_rn
(
gatef
,
sigf
),
upf
));
c
[
ind_c
]
=
__float2bfloat16_rn
(
__fmul_rn
(
__fmul_rn
(
gatef
,
sigf
),
upf
));
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
return
__fmul_rn
(
__fmul_rn
(
gate
,
sigmoid
(
gate
)),
up
);
c
[
ind_c
]
=
__fmul_rn
(
__fmul_rn
(
gate
,
sigmoid
(
gate
)),
up
);
}
else
{
return
gate
*
sigmoid
(
gate
)
*
up
;
c
[
ind_c
]
=
gate
*
sigmoid
(
gate
)
*
up
;
}
}
}
SwiGLUOp
;
}
// namespace op::swiglu::cuda
}
/// SwiGLU for bf16: c[i] = silu(b[i]) * a[i], one element per thread.
///
/// Fast path that assumes:
///  - `a` and `b` share the same strides — only b_strides_* are applied and
///    `ind_a` reuses `ind_b`. NOTE(review): confirm callers never pass a
///    differently-strided `a`.
///  - `c` is contiguous — it is indexed by the flat thread id; the c_strides_*
///    parameters are accepted but unused. TODO confirm with the launcher.
/// `batch` and the a_*/c_* stride parameters are kept for signature parity
/// with the generic SwiGLU kernel.
__device__ void CustomSwiGLUCudaKernel(
    __nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= length) {
        return;
    }
    // Decompose the flat index into (batch, seq, hidden) coordinates.
    int batchIdx = tid / (seq_len * hidden_dim);
    int seqIdx = (tid - batchIdx * seq_len * hidden_dim) / hidden_dim;
    int hiddenIdx = tid - (batchIdx * seq_len * hidden_dim + seqIdx * hidden_dim);
    int ind_c = tid; // contiguous output
    int ind_b = batchIdx * b_strides_0 + seqIdx * b_strides_1 + hiddenIdx * b_strides_2;
    int ind_a = ind_b; // `a` assumed to share `b`'s layout (see header note)
    __nv_bfloat16 gate = b[ind_b];
    __nv_bfloat16 up = a[ind_a];
    // silu(gate) = gate * sigmoid(gate). The sigmoid is computed in float and
    // rounded through bf16 so the result matches the reference kernel
    // bit-for-bit. The gate is widened exactly once (the original converted
    // it twice into identical values).
    float gatef = __bfloat162float(gate);
    cuda_bfloat16 sig = __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-gatef))));
    float sigf = __bfloat162float(sig);
    float upf = __bfloat162float(up);
    c[ind_c] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
}
/// Vectorized SwiGLU for bf16: each thread loads 8 gate and 8 up values with
/// a single float4 (16-byte) transaction, computes silu(gate) * up per lane,
/// and stores 8 results with one float4 write.
///
/// Assumptions (NOTE(review) — enforced by the launcher, confirm there):
///  - `length`, `hidden_dim`, and the outer strides arrive pre-divided by 8,
///    so indices here are in 8-element vector units and are shifted back
///    (`<< 3`) before addressing.
///  - `b + ind_b`, `a + ind_a`, and `c + ind_c` are 16-byte aligned;
///    misalignment would fault on the float4 accesses.
///  - `a` shares `b`'s strides and `c` is contiguous, as in the scalar
///    custom kernel.
__device__ void CustomVecSwiGLUCudaKernel(
    __nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= length) {
        return;
    }
    // (batch, seq, hidden) coordinates in vector units.
    int batchIdx = tid / (seq_len * hidden_dim);
    int seqIdx = (tid - batchIdx * seq_len * hidden_dim) / hidden_dim;
    int hiddenIdx = tid - (batchIdx * seq_len * hidden_dim + seqIdx * hidden_dim);
    // int ind_c = (batchIdx * c_strides_0 + seqIdx * c_strides_1 + hiddenIdx * c_strides_2) << 3;
    int ind_c = tid << 3; // back to element units; output assumed contiguous
    int ind_b = (batchIdx * b_strides_0 + seqIdx * b_strides_1 + hiddenIdx * b_strides_2) << 3;
    int ind_a = ind_b; // `a` assumed to share `b`'s layout
    __nv_bfloat16 gate[8];
    __nv_bfloat16 up[8];
    __nv_bfloat16 output[8];
    // One 16-byte load per operand covers eight bf16 values.
    const float4 *global_gate = reinterpret_cast<const float4 *>(b + ind_b);
    const float4 *global_up = reinterpret_cast<const float4 *>(a + ind_a);
    float4 *global_output = reinterpret_cast<float4 *>(c + ind_c);
    float4 gate_val = *global_gate;
    float4 up_val = *global_up;
    *reinterpret_cast<float4 *>(gate) = gate_val;
    *reinterpret_cast<float4 *>(up) = up_val;
#pragma unroll
    for (int i = 0; i < 8; i++) {
        // silu in float with the sigmoid rounded through bf16, matching the
        // scalar kernel bit-for-bit. The gate lane is widened exactly once
        // (the original converted it twice into identical values).
        float gatef = __bfloat162float(gate[i]);
        __nv_bfloat16 sig = __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-gatef))));
        float sigf = __bfloat162float(sig);
        float upf = __bfloat162float(up[i]);
        output[i] = __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
    }
    // Single 16-byte store for all eight results.
    *global_output = *reinterpret_cast<float4 *>(output);
}
#endif // __SWIGLU_CUDA_H__
#endif // __SWIGLU_CUDA_
KERNEL_CU
H__
src/infiniop/ops/swiglu/nvidia/swiglu_nvidia_cuda.cu
View file @
71cac971
...
...
@@ -19,6 +19,45 @@ INFINIOP_CUDA_KERNEL SwiGLUCuda(
b_strides_0
,
b_strides_1
,
b_strides_2
);
}
/// Kernel entry point for the scalar bf16 SwiGLU fast path; forwards every
/// argument unchanged to CustomSwiGLUCudaKernel (one element per thread).
INFINIOP_CUDA_KERNEL CustomSwiGLUCuda(
    __nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    CustomSwiGLUCudaKernel(c, a, b, length, batch, seq_len, hidden_dim,
                           c_strides_0, c_strides_1, c_strides_2,
                           a_strides_0, a_strides_1, a_strides_2,
                           b_strides_0, b_strides_1, b_strides_2);
}
/// Kernel entry point for the vectorized bf16 SwiGLU fast path. Each thread
/// of CustomVecSwiGLUCudaKernel handles 8 bf16 elements via one float4 load,
/// so the element count, the hidden extent, and the outer strides are scaled
/// down by the vector factor before forwarding. The innermost strides
/// (*_strides_2) are passed through in element units — presumably the inner
/// dimension is contiguous; verify against the launcher's divisibility check.
INFINIOP_CUDA_KERNEL CustomVecSwiGLUCuda(
    __nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b,
    int length,
    size_t batch, size_t seq_len, size_t hidden_dim,
    ptrdiff_t c_strides_0, ptrdiff_t c_strides_1, ptrdiff_t c_strides_2,
    ptrdiff_t a_strides_0, ptrdiff_t a_strides_1, ptrdiff_t a_strides_2,
    ptrdiff_t b_strides_0, ptrdiff_t b_strides_1, ptrdiff_t b_strides_2) {
    constexpr int VEC_FACTOR = 8;
    CustomVecSwiGLUCudaKernel(c, a, b,
                              length / VEC_FACTOR,
                              batch, seq_len,
                              hidden_dim / VEC_FACTOR,
                              c_strides_0 / VEC_FACTOR, c_strides_1 / VEC_FACTOR, c_strides_2,
                              a_strides_0 / VEC_FACTOR, a_strides_1 / VEC_FACTOR, a_strides_2,
                              b_strides_0 / VEC_FACTOR, b_strides_1 / VEC_FACTOR, b_strides_2);
}
namespace
op
::
swiglu_cuda
::
nvidia
{
struct
Descriptor
::
Opaque
{
...
...
@@ -68,13 +107,39 @@ infiniStatus_t calculate_swiglu_cuda(
ptrdiff_t
b_strides_1
=
info
.
b_strides_1
;
ptrdiff_t
b_strides_2
=
info
.
b_strides_2
;
int
num_blocks
=
(
length
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
SwiGLUCuda
<
T
,
BLOCK_SIZE
>
bool
vec_flag
=
false
;
//向量化取数据在这个长度下性能才最优
if
((
hidden_dim
%
8
==
0
)
&&
length
>=
295680
)
{
vec_flag
=
true
;
}
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
auto
bf16_c
=
reinterpret_cast
<
__nv_bfloat16
*>
(
c
);
auto
bf16_a
=
reinterpret_cast
<
const
__nv_bfloat16
*>
(
a
);
auto
bf16_b
=
reinterpret_cast
<
const
__nv_bfloat16
*>
(
b
);
if
(
vec_flag
)
{
int
block_size
=
256
;
int
grid_size
=
(
length
/
8
+
block_size
-
1
)
/
block_size
;
CustomVecSwiGLUCuda
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
bf16_c
,
bf16_a
,
bf16_b
,
length
,
batch
,
seq_len
,
hidden_dim
,
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
,
b_strides_2
);
}
else
{
int
block_size
=
256
;
int
grid_size
=
(
length
+
block_size
-
1
)
/
block_size
;
CustomSwiGLUCuda
<<<
grid_size
,
block_size
,
0
,
stream
>>>
(
bf16_c
,
bf16_a
,
bf16_b
,
length
,
batch
,
seq_len
,
hidden_dim
,
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
,
b_strides_2
);
}
}
else
{
int
num_blocks
=
(
length
+
BLOCK_SIZE
-
1
)
/
BLOCK_SIZE
;
SwiGLUCuda
<
T
,
BLOCK_SIZE
>
<<<
num_blocks
,
BLOCK_SIZE
,
0
,
stream
>>>
(
c
,
a
,
b
,
length
,
batch
,
seq_len
,
hidden_dim
,
c_strides_0
,
c_strides_1
,
c_strides_2
,
a_strides_0
,
a_strides_1
,
a_strides_2
,
b_strides_0
,
b_strides_1
,
b_strides_2
);
}
return
INFINI_STATUS_SUCCESS
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment