jerrrrry / infinicore / Commits / 3bdd832e

Commit 3bdd832e, authored Sep 17, 2025 by zhangyue (parent 6892a7f5)

issue/436: Support the 9g7b and 4b models

Showing 8 changed files with 72 additions and 18 deletions (+72 −18)
src/infiniop/devices/kunlun/kunlun_kernel_common.h              (+1 −1)
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc                     (+2 −0)
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu  (+6 −6)
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu                    (+8 −8)
test/infiniop/attention.py                                      (+1 −1)
test/infiniop/libinfiniop/utils.py                              (+51 −1)
test/infiniop/rearrange.py                                      (+2 −1)
test/infiniop/rms_norm.py                                       (+1 −0)
src/infiniop/devices/kunlun/kunlun_kernel_common.h

```diff
@@ -12,7 +12,7 @@
 namespace device::kunlun::kernel {

-#define SM_SIZE 10240
+#define SM_SIZE 40960

 /**
  * @brief Define ptrdiff_t and size_t for kunlun xpu
```
src/infiniop/ops/gemm/kunlun/gemm_kunlun.cc

```diff
@@ -102,6 +102,8 @@ infiniStatus_t Descriptor::calculate(
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         return INFINI_STATUS_SUCCESS;
     }));
+    xpu_wait(stream);
+    return INFINI_STATUS_SUCCESS;
 }
```
src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu

```diff
@@ -120,13 +120,13 @@ Descriptor::calculate(
         switch (_info.dt_p) {
         case INFINI_DTYPE_F16:
             LAUNCH_KERNEL(half, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_BF16:
             LAUNCH_KERNEL(bfloat16_t, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_F32:
             LAUNCH_KERNEL(float, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         default:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
@@ -135,13 +135,13 @@ Descriptor::calculate(
         switch (_info.dt_p) {
         case INFINI_DTYPE_F16:
             LAUNCH_KERNEL(half, int64_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_BF16:
             LAUNCH_KERNEL(bfloat16_t, int64_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_F32:
             LAUNCH_KERNEL(float, int64_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         default:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
```
src/infiniop/ops/rope/kunlun/rope_kunlun.xpu

```diff
@@ -159,13 +159,13 @@ infiniStatus_t Descriptor::calculate(
         switch (_info.data_type) {
         case INFINI_DTYPE_F32:
             LAUNCH_KERNEL(float, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_F16:
             LAUNCH_KERNEL(half, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_BF16:
             LAUNCH_KERNEL(bfloat16_t, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         default:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
@@ -173,20 +173,20 @@ infiniStatus_t Descriptor::calculate(
         switch (_info.data_type) {
         case INFINI_DTYPE_F32:
             LAUNCH_KERNEL(float, uint32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
        case INFINI_DTYPE_F16:
             LAUNCH_KERNEL(half, uint32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_BF16:
             LAUNCH_KERNEL(bfloat16_t, uint32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         default:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
-    }
-    else {
+    } else {
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
+    return INFINI_STATUS_SUCCESS;
 }
 } // namespace op::rope::kunlun
```
test/infiniop/attention.py

```diff
@@ -2,6 +2,7 @@ from ctypes import c_uint64
 import ctypes
 import sys
 import os
+import torch

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
 from libinfiniop import (
@@ -21,7 +22,6 @@ from libinfiniop import (
     infiniopOperatorDescriptor_t,
 )
-import torch

 def causal_softmax(x):
```
test/infiniop/libinfiniop/utils.py

```diff
 from typing import Sequence
 import torch
 import ctypes
+import numpy as np

 from .datatypes import *
 from .devices import *
 from .liboperators import infiniopTensorDescriptor_t, LIBINFINIOP, infiniopHandle_t
```
```diff
@@ -87,6 +88,12 @@ class TestTensor(CTensor):
             self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to(
                 torch_device_map[device]
             )
+        elif mode == "binary":
+            assert set_tensor is not None
+            assert torch_shape == list(set_tensor.shape)
+            self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to(
+                torch_device_map[device]
+            )
         else:
             raise ValueError("Unsupported mode")
```
```diff
@@ -95,7 +102,7 @@ class TestTensor(CTensor):
         if bias is not None:
             self._torch_tensor += bias

-        if strides is not None:
+        if strides is not None and mode != "binary":
             self._data_tensor = rearrange_tensor(self._torch_tensor, torch_strides)
         else:
             self._data_tensor = self._torch_tensor.clone()
```
```diff
@@ -114,6 +121,14 @@ class TestTensor(CTensor):
     def is_broadcast(self):
         return self.strides is not None and 0 in self.strides

+    @staticmethod
+    def from_binary(binary_file, shape, strides, dt: InfiniDtype, device: InfiniDeviceEnum):
+        data = np.fromfile(binary_file, dtype=to_numpy_dtype(dt))
+        base = torch.from_numpy(data)
+        torch_tensor = torch.as_strided(base, size=shape, stride=strides).to(torch_device_map[device])
+        return TestTensor(shape, strides, dt, device, mode="binary", set_tensor=torch_tensor)
+
     @staticmethod
     def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum):
         shape_ = list(torch_tensor.shape)
```
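The binary path added above is what the new TestTensor.from_binary helper feeds. A minimal usage sketch, assuming these names are re-exported from the libinfiniop test package the way test/infiniop/attention.py imports them; the dump file, shape, and strides are hypothetical:

```python
# Hypothetical use of the new TestTensor.from_binary helper. The file name,
# shape, and strides are made up for illustration; the import path assumes
# the names are re-exported from the libinfiniop test package.
from libinfiniop import TestTensor, InfiniDtype, InfiniDeviceEnum

# from_binary reads raw values with np.fromfile, re-views them with
# torch.as_strided, and passes mode="binary", so the constructor keeps the
# captured layout instead of running it through rearrange_tensor.
q = TestTensor.from_binary(
    "q.bin",                 # hypothetical dump from an earlier run
    shape=(15, 3584),
    strides=(3584, 1),
    dt=InfiniDtype.F16,
    device=InfiniDeviceEnum.CPU,
)
print(q._torch_tensor.shape)  # torch.Size([15, 3584])
```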
```diff
@@ -154,6 +169,38 @@ def to_torch_dtype(dt: InfiniDtype, compatability_mode=False):
         raise ValueError("Unsupported data type")

+def to_numpy_dtype(dt: InfiniDtype, compatability_mode=False):
+    if dt == InfiniDtype.I8:
+        return np.int8
+    elif dt == InfiniDtype.I16:
+        return np.int16
+    elif dt == InfiniDtype.I32:
+        return np.int32
+    elif dt == InfiniDtype.I64:
+        return np.int64
+    elif dt == InfiniDtype.U8:
+        return np.uint8
+    elif dt == InfiniDtype.U16:
+        return np.uint16 if not compatability_mode else np.int16
+    elif dt == InfiniDtype.U32:
+        return np.uint32 if not compatability_mode else np.int32
+    elif dt == InfiniDtype.U64:
+        return np.uint64 if not compatability_mode else np.int64
+    elif dt == InfiniDtype.F16:
+        return np.float16
+    elif dt == InfiniDtype.BF16:
+        # numpy 1.20+ stacks can emulate bf16 via np.dtype("bfloat16"),
+        # but many environments lack direct support and usually need to
+        # fall back to float32.
+        return np.dtype("bfloat16") if not compatability_mode else np.float32
+    elif dt == InfiniDtype.F32:
+        return np.float32
+    elif dt == InfiniDtype.F64:
+        return np.float64
+    else:
+        raise ValueError("Unsupported data type")

 class TestWorkspace:
     def __init__(self, size, device):
         if size != 0:
```
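A quick sanity check of the compatability_mode fallbacks in to_numpy_dtype, as a sketch; the import paths mirror the package layout (libinfiniop/utils.py, libinfiniop/datatypes.py) but are an assumption:

```python
# Sketch: compatability_mode trades exact dtypes for ones plain numpy has.
import numpy as np
from libinfiniop.datatypes import InfiniDtype   # assumed import path
from libinfiniop.utils import to_numpy_dtype    # assumed import path

assert to_numpy_dtype(InfiniDtype.F16) is np.float16
# bf16 widens to float32 when the environment has no "bfloat16" dtype.
assert to_numpy_dtype(InfiniDtype.BF16, compatability_mode=True) is np.float32
# Unsigned ints degrade to same-width signed ints in compatibility mode.
assert to_numpy_dtype(InfiniDtype.U32, compatability_mode=True) is np.int32
```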
```diff
@@ -422,6 +469,9 @@ def print_discrepancy(
     is_terminal = sys.stdout.isatty()

+    actual = actual.to("cpu")
+    expected = expected.to("cpu")
+
     actual_isnan = torch.isnan(actual)
     expected_isnan = torch.isnan(expected)
```
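Moving actual and expected to the CPU up front keeps the element-wise checks that follow on a single device. A toy version of the NaN-aware comparison these masks enable (the exact logic in print_discrepancy may differ):

```python
# Toy NaN-aware mismatch check, runnable entirely on CPU.
import torch

actual = torch.tensor([1.0, float("nan"), 3.0])
expected = torch.tensor([1.0, float("nan"), 2.0])
both_nan = torch.isnan(actual) & torch.isnan(expected)
mismatch = ~both_nan & (actual != expected)  # NaN==NaN counts as a match here
print(mismatch)  # tensor([False, False,  True])
```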
test/infiniop/rearrange.py

```diff
@@ -75,6 +75,7 @@ _TEST_CASES = [
         row_major_strides((3, 4, 50, 50, 5, 7)),  # x_stride
         column_major_strides((3, 4, 50, 50, 5, 7)),  # y_stride
     ),
+    ((15, 10752), (0, 1), (10752, 1)),
 ]

 # Data types used for testing
@@ -94,7 +95,7 @@ NUM_ITERATIONS = 1000
 def rearrange_torch(y, x, x_shape, y_stride):
     y.set_(y.untyped_storage(), 0, x_shape, y_stride)
-    y[:] = x.view_as(y)
+    y.copy_(x.expand_as(y))

 def test(
```
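The replaced line matters for the new broadcast entry, whose x_stride of (0, 1) repeats one stored row: view_as can only reinterpret the same number of elements, while expand_as broadcasts without copying and copy_ then materializes the result into y. A standalone sketch with made-up sizes:

```python
# Why copy_(expand_as(...)) replaces the view_as assignment: a broadcast
# source has fewer stored elements than the destination. Sizes are made up.
import torch

x = torch.arange(8, dtype=torch.float32).reshape(1, 8)  # one stored row
y = torch.empty(3, 8)

y.copy_(x.expand_as(y))   # fine: expand_as broadcasts the row, copy_ fills y
# x.view_as(y)            # would raise: 8 elements cannot view as (3, 8)
print(torch.equal(y[0], y[2]))  # True: every row is the same source row
```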
test/infiniop/rms_norm.py

```diff
@@ -30,6 +30,7 @@ _TEST_CASES_ = [
     ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
     ((16, 2048), (16, 2048), (2048,), None, None),
     ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
+    ((15, 3584), (15, 3584), (3584,), None, None),
     ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
     ((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
     ((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
```
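For context, the new (15, 3584) case exercises RMS normalization at the hidden size of the newly supported models. A standard torch formulation of RMS norm, as a sketch (the epsilon and the exact reference implementation in rms_norm.py may differ):

```python
# Standard RMS norm over the last dimension; the eps value is an assumption.
import torch

def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each row by the reciprocal of its root-mean-square, then by w.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w

x = torch.randn(15, 3584)        # matches the newly added test case
w = torch.ones(3584)
print(rms_norm_ref(x, w).shape)  # torch.Size([15, 3584])
```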