Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
063beaec
Commit
063beaec
authored
Apr 09, 2025
by
PanZezhong
Browse files
issue/4 更新接口
parent
6b717b30
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
17 deletions
+37
-17
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
+22
-8
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
...nfiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
+15
-9
No files found.
src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cu
View file @
063beaec
...
...
@@ -16,8 +16,9 @@ Descriptor::~Descriptor() {
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
)
{
auto
info
=
CausalSoftmaxInfo
::
create
(
y_desc
);
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
)
{
auto
info
=
CausalSoftmaxInfo
::
create
(
y_desc
,
x_desc
);
CHECK_RESULT
(
info
);
*
desc_ptr
=
new
Descriptor
(
new
Opaque
{
reinterpret_cast
<
device
::
cuda
::
Handle
*>
(
handle
)
->
internal
()},
...
...
@@ -26,14 +27,24 @@ infiniStatus_t Descriptor::create(
}
// Launch the causal-softmax kernel for the post-commit (y, x) interface.
// Grid layout: grid.x = seq_len (one block per row), grid.y = batch_size;
// BLOCK_SIZE threads cooperate on one row. F16 accumulates in float.
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                            size_t batch_size, size_t seq_len, size_t total_seq_len,
                            ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                            ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
                            cudaStream_t stream) {
    dim3 grid(uint32_t(seq_len), uint32_t(batch_size), 1);
    if (dtype == INFINI_DTYPE_F16) {
        causalSoftmax<BLOCK_SIZE, half, float>
            <<<grid, BLOCK_SIZE, 0, stream>>>(
                (half *)y, (const half *)x,
                batch_size, seq_len, total_seq_len,
                y_stride_b, y_stride_i,
                x_stride_b, x_stride_i);
    } else if (dtype == INFINI_DTYPE_F32) {
        causalSoftmax<BLOCK_SIZE, float, float>
            <<<grid, BLOCK_SIZE, 0, stream>>>(
                (float *)y, (const float *)x,
                batch_size, seq_len, total_seq_len,
                y_stride_b, y_stride_i,
                x_stride_b, x_stride_i);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // NOTE(review): no cudaGetLastError() after the launch — bad launch
    // configurations are silently dropped; consider mapping it to a status.
    // NOTE(review): the diff view elided the tail of this function; the
    // success return below is reconstructed — verify against the repository.
    return INFINI_STATUS_SUCCESS;
}
// Run causal softmax: y = causal_softmax(x), shapes/strides taken from _info.
// `workspace`/`workspace_size` are unused by this operator (kept for the
// uniform operator interface). `stream_` is an opaque cudaStream_t.
// Dispatches on the device's max threads-per-block so the kernel template is
// instantiated with a BLOCK_SIZE the hardware can actually launch; returns
// INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED otherwise.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            y, x, _info.dtype,
            _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i,
            _info.x_stride_b, _info.x_stride_i,
            stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            y, x, _info.dtype,
            _info.batch_size, _info.seq_len, _info.total_seq_len,
            _info.y_stride_b, _info.y_stride_i,
            _info.x_stride_b, _info.x_stride_i,
            stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    // NOTE(review): the diff view elided the function tail; the success return
    // below is reconstructed — verify against the repository.
    return INFINI_STATUS_SUCCESS;
}
...
...
src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh
View file @
063beaec
...
...
@@ -4,14 +4,20 @@
#include "../../../devices/cuda/cuda_common.cuh"
// Causal softmax over one row per block: y[row] = softmax(mask(x[row])).
// Grid layout expected: blockIdx.y = batch index, blockIdx.x = row index,
// threadIdx.x walks the columns with stride BLOCK_SIZE (launch uses
// BLOCK_SIZE threads). Tdata is the storage type (half/float), Tcompute the
// accumulation type for the sum reduction.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_CUDA_KERNEL causalSoftmax(
    Tdata *y_, const Tdata *x_,
    size_t batch, size_t height, size_t width,
    ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
    ptrdiff_t x_stride_b, ptrdiff_t x_stride_h) {
    // Base pointers for this block's (batch, row) slice.
    Tdata *y = y_
        + blockIdx.y * y_stride_b  // gridDim.y for batch_id
        + blockIdx.x * y_stride_h; // gridDim.x for row_id
    const Tdata *x = x_
        + blockIdx.y * x_stride_b
        + blockIdx.x * x_stride_h;

    // [Reduce] Find the max of the input row and publish it to the block.
    __shared__ Tdata max_;
    Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width);
    if (threadIdx.x == 0) {
        max_ = max_0;
    }
    // NOTE(review): the diff view elided several lines here; the barrier and
    // the loop header below are reconstructed — verify against the repository.
    __syncthreads();

    // [Elementwise] Causal mask + exponentiate. Kept region keeps column col
    // of row r when col <= width - height + r, e.g. width 6, height 3:
    //   0 | * * * *     |
    //   1 | * * * * *   |
    //   2 | * * * * * * |
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        // NOTE(review): this mask tests threadIdx.x, not col (condition copied
        // verbatim from the visible source). If rows are wider than BLOCK_SIZE
        // the later column chunks would be mis-masked — confirm the intended
        // invariant (width <= BLOCK_SIZE?) or switch the test to `col`.
        if (width + blockIdx.x >= threadIdx.x + height) {
            y[col] = exp(x[col] - max_);
        } else {
            y[col] = Tdata(0);
        }
    }
    __syncthreads();

    // [Reduce] Sum the masked exponentials and publish to the block.
    __shared__ Tcompute sum_;
    Tcompute sum_0 = op::common_cuda::reduce_op::sum<BLOCK_SIZE, Tdata, Tcompute>(y, width);
    if (threadIdx.x == 0) {
        sum_ = sum_0;
    }
    // NOTE(review): one elided line here; a barrier is required before reading
    // sum_, so it is reconstructed as __syncthreads() — verify.
    __syncthreads();

    // [Elementwise] Divide each element by the row sum.
    for (size_t col = threadIdx.x; col < width; col += BLOCK_SIZE) {
        y[col] /= Tdata(sum_);
    }
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment