Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f46e9f65
"vscode:/vscode.git/clone" did not exist on "1e30fac0f14603fa75e827685126a36b7a887e93"
Commit
f46e9f65
authored
Feb 12, 2026
by
zhangyue
Browse files
issue/1008: adapt lpnorm layernorm softmax rearrange paged_attention for iluvatar
parent
bd0c922a
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
96 additions
and
22 deletions
+96
-22
scripts/python_test.py
scripts/python_test.py
+6
-3
src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu
src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu
+2
-0
src/infiniop/ops/layer_norm/operator.cc
src/infiniop/ops/layer_norm/operator.cc
+3
-0
src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu
src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu
+5
-0
src/infiniop/ops/logsoftmax/operator.cc
src/infiniop/ops/logsoftmax/operator.cc
+4
-4
src/infiniop/ops/lp_norm/nvidia/lp_norm_nvidia.cu
src/infiniop/ops/lp_norm/nvidia/lp_norm_nvidia.cu
+2
-0
src/infiniop/ops/rearrange/nvidia/rearrange_kernel.cuh
src/infiniop/ops/rearrange/nvidia/rearrange_kernel.cuh
+20
-3
src/infiniop/ops/sigmoid/operator.cc
src/infiniop/ops/sigmoid/operator.cc
+13
-4
src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
+3
-0
src/infiniop/ops/topksoftmax/operator.cc
src/infiniop/ops/topksoftmax/operator.cc
+13
-1
test/infiniop/ones.py
test/infiniop/ones.py
+7
-0
test/infiniop/zeros.py
test/infiniop/zeros.py
+7
-0
xmake.lua
xmake.lua
+5
-3
xmake/iluvatar.lua
xmake/iluvatar.lua
+6
-4
No files found.
scripts/python_test.py
View file @
f46e9f65
...
@@ -20,9 +20,9 @@ def run_tests(args):
...
@@ -20,9 +20,9 @@ def run_tests(args):
#"dequantize_awq.py",
#"dequantize_awq.py",
"gelu.py"
,
"gelu.py"
,
"gemm.py"
,
"gemm.py"
,
#
"layer_norm.py",
"layer_norm.py"
,
"logsoftmax.py"
,
"logsoftmax.py"
,
#
"lp_norm.py",
"lp_norm.py"
,
"mul.py"
,
"mul.py"
,
"ones.py"
,
"ones.py"
,
"random_sample.py"
,
"random_sample.py"
,
...
@@ -31,7 +31,7 @@ def run_tests(args):
...
@@ -31,7 +31,7 @@ def run_tests(args):
"rms_norm.py"
,
"rms_norm.py"
,
"rope.py"
,
"rope.py"
,
"sigmoid.py"
,
"sigmoid.py"
,
#
"softmax.py",
"softmax.py"
,
"softplus.py"
,
"softplus.py"
,
"sub.py"
,
"sub.py"
,
"swiglu.py"
,
"swiglu.py"
,
...
@@ -39,6 +39,9 @@ def run_tests(args):
...
@@ -39,6 +39,9 @@ def run_tests(args):
"topkrouter.py"
,
"topkrouter.py"
,
"topksoftmax.py"
,
"topksoftmax.py"
,
"zeros.py"
,
"zeros.py"
,
"paged_attention.py"
,
"paged_caching.py"
,
"paged_attention_prefill.py"
]:
]:
result
=
subprocess
.
run
(
result
=
subprocess
.
run
(
f
"python
{
test
}
{
args
}
--debug"
,
text
=
True
,
encoding
=
"utf-8"
,
shell
=
True
f
"python
{
test
}
{
args
}
--debug"
,
text
=
True
,
encoding
=
"utf-8"
,
shell
=
True
...
...
src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu
View file @
f46e9f65
...
@@ -255,6 +255,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -255,6 +255,8 @@ infiniStatus_t Descriptor::calculate(
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_512
)
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_512
)
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_4096
)
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_4096
)
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_2048
)
{
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_2048
)
}
else
{
}
else
{
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
}
}
...
...
src/infiniop/ops/layer_norm/operator.cc
View file @
f46e9f65
...
@@ -174,6 +174,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
...
@@ -174,6 +174,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/logsoftmax/nvidia/logsoftmax_nvidia.cu
View file @
f46e9f65
...
@@ -117,6 +117,11 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
...
@@ -117,6 +117,11 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
y
,
x
,
_info
.
x_dtype
,
_info
.
y_dtype
,
_info
.
batch_size
,
_info
.
probs_size
,
_info
.
ndim
,
_info
.
seq_len
,
y
,
x
,
_info
.
x_dtype
,
_info
.
y_dtype
,
_info
.
batch_size
,
_info
.
probs_size
,
_info
.
ndim
,
_info
.
seq_len
,
_info
.
y_stride_b
,
_info
.
y_stride_p
,
_info
.
x_stride_b
,
_info
.
x_stride_p
,
_info
.
y_stride_b
,
_info
.
y_stride_p
,
_info
.
x_stride_b
,
_info
.
x_stride_p
,
_info
.
y_stride_0
,
_info
.
y_stride_1
,
_info
.
x_stride_0
,
_info
.
x_stride_1
,
stream
));
_info
.
y_stride_0
,
_info
.
y_stride_1
,
_info
.
x_stride_0
,
_info
.
x_stride_1
,
stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_2048
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_2048
>
(
y
,
x
,
_info
.
x_dtype
,
_info
.
y_dtype
,
_info
.
batch_size
,
_info
.
probs_size
,
_info
.
ndim
,
_info
.
seq_len
,
_info
.
y_stride_b
,
_info
.
y_stride_p
,
_info
.
x_stride_b
,
_info
.
x_stride_p
,
_info
.
y_stride_0
,
_info
.
y_stride_1
,
_info
.
x_stride_0
,
_info
.
x_stride_1
,
stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
y
,
x
,
_info
.
x_dtype
,
_info
.
y_dtype
,
_info
.
batch_size
,
_info
.
probs_size
,
_info
.
ndim
,
_info
.
seq_len
,
y
,
x
,
_info
.
x_dtype
,
_info
.
y_dtype
,
_info
.
batch_size
,
_info
.
probs_size
,
_info
.
ndim
,
_info
.
seq_len
,
...
...
src/infiniop/ops/logsoftmax/operator.cc
View file @
f46e9f65
...
@@ -40,7 +40,7 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
...
@@ -40,7 +40,7 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_ILUVATAR_API
//
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
...
@@ -73,7 +73,7 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
...
@@ -73,7 +73,7 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_ILUVATAR_API
//
GET(INFINI_DEVICE_ILUVATAR, nvidia);
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
GET
(
INFINI_DEVICE_QY
,
nvidia
);
...
@@ -111,7 +111,7 @@ __C infiniStatus_t infiniopLogSoftmax(
...
@@ -111,7 +111,7 @@ __C infiniStatus_t infiniopLogSoftmax(
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_ILUVATAR_API
//
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
...
@@ -144,7 +144,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
...
@@ -144,7 +144,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_ILUVATAR_API
//
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
#ifdef ENABLE_QY_API
#ifdef ENABLE_QY_API
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
...
...
src/infiniop/ops/lp_norm/nvidia/lp_norm_nvidia.cu
View file @
f46e9f65
...
@@ -155,6 +155,8 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
...
@@ -155,6 +155,8 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_1024
)
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_1024
)
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_512
)
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_512
)
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_2048
)
{
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_2048
)
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_4096
)
CALCULATE_LP_NORM_WITH_BLOCK_SIZE
(
CUDA_BLOCK_SIZE_4096
)
}
else
{
}
else
{
...
...
src/infiniop/ops/rearrange/nvidia/rearrange_kernel.cuh
View file @
f46e9f65
...
@@ -8,8 +8,8 @@
...
@@ -8,8 +8,8 @@
#define ARRAY_TYPE_SIZE size_t
#define ARRAY_TYPE_SIZE size_t
// 与 DEFINE_KERNELS_BY_CONSTRAINT 耦合,需要同时修改
// 与 DEFINE_KERNELS_BY_CONSTRAINT 耦合,需要同时修改
#define MAX_BLOCK_ARRAY_SIZE
5
#define MAX_BLOCK_ARRAY_SIZE
6
#define MAX_GRID_ARRAY_SIZE
5
#define MAX_GRID_ARRAY_SIZE
6
template
<
int
ArrSize
,
typename
ArrayType
>
template
<
int
ArrSize
,
typename
ArrayType
>
struct
ArrayStruct
{
struct
ArrayStruct
{
...
@@ -185,32 +185,43 @@ struct Constraint {
...
@@ -185,32 +185,43 @@ struct Constraint {
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
// 与 MAX_BLOCK_ARRAY_SIZE 和 MAX_GRID_ARRAY_SIZE 耦合,需要同时修改
// 与 MAX_BLOCK_ARRAY_SIZE 和 MAX_GRID_ARRAY_SIZE 耦合,需要同时修改
// 为1-
5
和1-
5
的所有组合生成内核
// 为1-
6
和1-
6
的所有组合生成内核
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
6
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
2
,
6
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
3
,
6
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
4
,
6
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
6
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
6
,
6
)
// 准备参数结构体
// 准备参数结构体
struct
RearrangeParams
{
struct
RearrangeParams
{
...
@@ -294,6 +305,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
...
@@ -294,6 +305,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
case 5: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 5); \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 5); \
break; \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_CONSTRAINT(block_array_size, 6); \
break; \
}
}
#define GET_REARRANGE_KERNEL_BY_BLOCK_NUM \
#define GET_REARRANGE_KERNEL_BY_BLOCK_NUM \
...
@@ -313,6 +327,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
...
@@ -313,6 +327,9 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
case 5: \
case 5: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(5); \
GET_REARRANGE_KERNEL_BY_GRID_NUM(5); \
break; \
break; \
case 6: \
GET_REARRANGE_KERNEL_BY_GRID_NUM(6); \
break; \
}
}
GET_REARRANGE_KERNEL_BY_BLOCK_NUM
GET_REARRANGE_KERNEL_BY_BLOCK_NUM
...
...
src/infiniop/ops/sigmoid/operator.cc
View file @
f46e9f65
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
#include "cpu/sigmoid_cpu.h"
#include "cpu/sigmoid_cpu.h"
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
|| defined(ENABLE_ILUVATAR_API)
#include "nvidia/sigmoid_nvidia.cuh"
#include "nvidia/sigmoid_nvidia.cuh"
#endif
#endif
...
@@ -37,6 +37,9 @@ __C infiniStatus_t infiniopCreateSigmoidDescriptor(
...
@@ -37,6 +37,9 @@ __C infiniStatus_t infiniopCreateSigmoidDescriptor(
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -65,7 +68,9 @@ __C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t d
...
@@ -65,7 +68,9 @@ __C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t d
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -101,7 +106,9 @@ __C infiniStatus_t infiniopSigmoid(
...
@@ -101,7 +106,9 @@ __C infiniStatus_t infiniopSigmoid(
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
@@ -131,7 +138,9 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {
...
@@ -131,7 +138,9 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
DELETE
(
INFINI_DEVICE_ALI
,
nvidia
);
DELETE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
...
...
src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
View file @
f46e9f65
...
@@ -128,6 +128,9 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
...
@@ -128,6 +128,9 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
y
,
x
,
_info
.
dtype
,
_info
.
othersize
,
_info
.
dimsize
,
_info
.
stride
,
stream
));
y
,
x
,
_info
.
dtype
,
_info
.
othersize
,
_info
.
dimsize
,
_info
.
stride
,
stream
));
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_2048
)
{
CHECK_STATUS
(
launchKernel
<
CUDA_BLOCK_SIZE_2048
>
(
y
,
x
,
_info
.
dtype
,
_info
.
othersize
,
_info
.
dimsize
,
_info
.
stride
,
stream
));
}
else
{
}
else
{
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
}
}
...
...
src/infiniop/ops/topksoftmax/operator.cc
View file @
f46e9f65
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
#include "cpu/topksoftmax_cpu.h"
#include "cpu/topksoftmax_cpu.h"
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
|| defined(ENABLE_ILUVATAR_API)
#include "nvidia/topksoftmax_nvidia.cuh"
#include "nvidia/topksoftmax_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_METAX_API
...
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
...
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#endif
#endif
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
}
}
...
@@ -66,6 +69,9 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
...
@@ -66,6 +69,9 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#endif
#endif
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
}
}
...
@@ -101,6 +107,9 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
...
@@ -101,6 +107,9 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#endif
#endif
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
}
}
...
@@ -131,6 +140,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
...
@@ -131,6 +140,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#endif
#endif
#ifdef ENABLE_ALI_API
#ifdef ENABLE_ALI_API
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#endif
}
}
...
...
test/infiniop/ones.py
View file @
f46e9f65
...
@@ -15,6 +15,7 @@ from libinfiniop import (
...
@@ -15,6 +15,7 @@ from libinfiniop import (
InfiniDtype
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDtypeNames
,
InfiniDeviceNames
,
InfiniDeviceNames
,
InfiniDeviceEnum
,
infiniopOperatorDescriptor_t
,
infiniopOperatorDescriptor_t
,
)
)
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
...
@@ -112,6 +113,12 @@ def test(
...
@@ -112,6 +113,12 @@ def test(
dtype
=
None
,
dtype
=
None
,
sync
=
None
,
sync
=
None
,
):
):
# Skip strided cases on Iluvatar: Ones with non-contiguous tensors can hang the GPU (requires ixsmi -r to recover)
if
device
==
InfiniDeviceEnum
.
ILUVATAR
and
(
x_stride
is
not
None
or
y_stride
is
not
None
):
return
if
dtype
in
[
InfiniDtype
.
F16
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F32
,
InfiniDtype
.
F64
]:
if
dtype
in
[
InfiniDtype
.
F16
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F32
,
InfiniDtype
.
F64
]:
x
=
TestTensor
(
shape
,
x_stride
,
dtype
,
device
)
x
=
TestTensor
(
shape
,
x_stride
,
dtype
,
device
)
elif
dtype
in
[
InfiniDtype
.
BYTE
,
InfiniDtype
.
U8
,
InfiniDtype
.
U16
,
InfiniDtype
.
U32
,
InfiniDtype
.
U64
,
elif
dtype
in
[
InfiniDtype
.
BYTE
,
InfiniDtype
.
U8
,
InfiniDtype
.
U16
,
InfiniDtype
.
U32
,
InfiniDtype
.
U64
,
...
...
test/infiniop/zeros.py
View file @
f46e9f65
...
@@ -15,6 +15,7 @@ from libinfiniop import (
...
@@ -15,6 +15,7 @@ from libinfiniop import (
InfiniDtype
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDtypeNames
,
InfiniDeviceNames
,
InfiniDeviceNames
,
InfiniDeviceEnum
,
infiniopOperatorDescriptor_t
,
infiniopOperatorDescriptor_t
,
)
)
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
...
@@ -114,6 +115,12 @@ def test(
...
@@ -114,6 +115,12 @@ def test(
dtype
=
None
,
dtype
=
None
,
sync
=
None
,
sync
=
None
,
):
):
# Skip strided cases on Iluvatar: Zeros with non-contiguous tensors can hang the GPU (requires ixsmi -r to recover)
if
device
==
InfiniDeviceEnum
.
ILUVATAR
and
(
x_stride
is
not
None
or
y_stride
is
not
None
):
return
if
dtype
in
[
InfiniDtype
.
F16
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F32
,
InfiniDtype
.
F64
]:
if
dtype
in
[
InfiniDtype
.
F16
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F32
,
InfiniDtype
.
F64
]:
x
=
TestTensor
(
shape
,
x_stride
,
dtype
,
device
)
x
=
TestTensor
(
shape
,
x_stride
,
dtype
,
device
)
elif
dtype
in
[
InfiniDtype
.
BYTE
,
InfiniDtype
.
U8
,
InfiniDtype
.
U16
,
InfiniDtype
.
U32
,
InfiniDtype
.
U64
,
elif
dtype
in
[
InfiniDtype
.
BYTE
,
InfiniDtype
.
U8
,
InfiniDtype
.
U16
,
InfiniDtype
.
U32
,
InfiniDtype
.
U64
,
...
...
xmake.lua
View file @
f46e9f65
...
@@ -115,10 +115,12 @@ option("iluvatar-gpu")
...
@@ -115,10 +115,12 @@ option("iluvatar-gpu")
set_description
(
"Whether to compile implementations for Iluvatar GPU"
)
set_description
(
"Whether to compile implementations for Iluvatar GPU"
)
option_end
()
option_end
()
option
(
"i
vcore-20
"
)
option
(
"i
luvatar_arch
"
)
set_default
(
false
)
set_default
(
"ivcore20"
)
set_showmenu
(
true
)
set_showmenu
(
true
)
set_description
(
"Use ivcore20"
)
set_description
(
"Set Iluvatar GPU architecture (e.g. ivcore20)"
)
set_values
(
"ivcore20"
)
set_category
(
"option"
)
option_end
()
option_end
()
if
has_config
(
"iluvatar-gpu"
)
then
if
has_config
(
"iluvatar-gpu"
)
then
...
...
xmake/iluvatar.lua
View file @
f46e9f65
toolchain
(
"iluvatar.toolchain"
)
local
iluvatar_arch
=
get_config
(
"iluvatar_arch"
)
or
"ivcore20"
toolchain
(
"iluvatar.toolchain"
)
set_toolset
(
"cc"
,
"clang"
)
set_toolset
(
"cc"
,
"clang"
)
set_toolset
(
"cxx"
,
"clang++"
)
set_toolset
(
"cxx"
,
"clang++"
)
set_toolset
(
"cu"
,
"clang++"
)
set_toolset
(
"cu"
,
"clang++"
)
...
@@ -44,9 +46,7 @@ target("infiniop-iluvatar")
...
@@ -44,9 +46,7 @@ target("infiniop-iluvatar")
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
add_cuflags
(
"-Wno-error=unused-private-field"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
add_cuflags
(
"-Wno-error=unused-private-field"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
if
has_config
(
"ivcore-20"
)
then
add_cuflags
(
"--cuda-gpu-arch="
..
iluvatar_arch
,
{
force
=
true
})
add_cuflags
(
"--cuda-gpu-arch=ivcore20"
,
{
force
=
true
})
end
add_culdflags
(
"-fPIC"
)
add_culdflags
(
"-fPIC"
)
add_cxflags
(
"-fPIC"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
add_cxflags
(
"-fPIC"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
add_cxxflags
(
"-fPIC"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
add_cxxflags
(
"-fPIC"
,
"-Wno-error=unused-variable"
,
"-Wno-unused-variable"
)
...
@@ -75,6 +75,7 @@ target("infinirt-iluvatar")
...
@@ -75,6 +75,7 @@ target("infinirt-iluvatar")
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
add_cuflags
(
"--cuda-gpu-arch="
..
iluvatar_arch
,
{
force
=
true
})
add_culdflags
(
"-fPIC"
)
add_culdflags
(
"-fPIC"
)
add_cxflags
(
"-fPIC"
)
add_cxflags
(
"-fPIC"
)
add_cxxflags
(
"-fPIC"
)
add_cxxflags
(
"-fPIC"
)
...
@@ -97,6 +98,7 @@ target("infiniccl-iluvatar")
...
@@ -97,6 +98,7 @@ target("infiniccl-iluvatar")
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
add_cuflags
(
"-fPIC"
,
"-x"
,
"ivcore"
,
"-std=c++17"
,
{
force
=
true
})
add_cuflags
(
"--cuda-gpu-arch="
..
iluvatar_arch
,
{
force
=
true
})
add_culdflags
(
"-fPIC"
)
add_culdflags
(
"-fPIC"
)
add_cxflags
(
"-fPIC"
)
add_cxflags
(
"-fPIC"
)
add_cxxflags
(
"-fPIC"
)
add_cxxflags
(
"-fPIC"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment