jerrrrry / infinicore · Commits

Commit 3138cbdd, authored Dec 22, 2025 by wooway777
Parent: 215d1932

issue/824 - removed redundant batch handling
Showing 2 changed files, with 114 additions and 28 deletions (+114, -28):

    python/infinicore/nn/functional/rope.py  (+2, -13)
    test/infinicore/ops/rope.py              (+112, -15)
python/infinicore/nn/functional/rope.py

@@ -20,16 +20,6 @@ def rope(
 ) -> Tensor:
     r"""Rotary Position Embedding(RoPE)."""
-    bs, seq_len, num_heads, head_dim = x.shape
-    x_stride = x.stride()
-    assert seq_len * x_stride[1] == x_stride[0], (
-        "x need to be continuous in dim=0 and dim=1"
-    )
-    x = x.view((bs * seq_len, num_heads, head_dim))
-    bs, num = pos_ids.shape
-    pos_ids = pos_ids.view((bs * num,))
     if out is None:
         return Tensor(
             _infinicore.rope(
@@ -39,9 +29,8 @@ def rope(
                 cos_table._underlying,
                 algo,
-            ).view((bs, seq_len, num_heads, head_dim))
+            )
         )
-    out = out.view((bs * seq_len, num_heads, head_dim))
     _infinicore.rope_(
         out._underlying,
         x._underlying,
@@ -50,4 +39,4 @@ def rope(
         cos_table._underlying,
         algo,
     )
-    return out.view((bs, seq_len, num_heads, head_dim))
+    return out
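The deleted lines collapsed (bs, seq_len) into one leading dimension before dispatching and re-expanded it afterwards; the kernel now receives the 4-D batched tensor as-is. That round-trip was a pure metadata operation, and it also forced x to be contiguous across dims 0 and 1. A runnable PyTorch sketch of the equivalence (PyTorch stands in for infinicore tensors here; shapes are taken from the test cases below):

    import torch

    bs, seq_len, num_heads, head_dim = 2, 20, 16, 128
    x = torch.randn(bs, seq_len, num_heads, head_dim)

    # What the deleted lines did: flatten batch + sequence, then restore.
    flat = x.view(bs * seq_len, num_heads, head_dim)
    restored = flat.view(bs, seq_len, num_heads, head_dim)
    assert restored.data_ptr() == x.data_ptr()  # same storage, no copy
    assert torch.equal(restored, x)             # same values

    # The flatten only works when seq_len * stride[1] == stride[0] (the
    # deleted assert); handling the batch inside the kernel lifts that
    # restriction, which the new strided test cases below exercise.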
test/infinicore/ops/rope.py

@@ -22,11 +22,85 @@ import infinicore
 _TEST_CASES_DATA = [
-    # bs, seq_len, num, head_dim, Algo
-    (1, 1, 1, 64, RopeAlgo.GPT_NEOX),
-    (1, 5, 32, 64, RopeAlgo.GPT_NEOX),
-    (1, 1, 1, 128, RopeAlgo.GPT_J),
-    (1, 10, 1, 64, RopeAlgo.GPT_J),
+    # bs, seq_len, num, head_dim, src strides, dst strides, Algo
+    (1, 1, 1, 64, None, None, RopeAlgo.GPT_NEOX),
+    (1, 5, 32, 64, None, None, RopeAlgo.GPT_NEOX),
+    (1, 1, 1, 128, None, None, RopeAlgo.GPT_J),
+    (1, 10, 1, 64, None, None, RopeAlgo.GPT_J),
+    (2, 20, 16, 128, None, None, RopeAlgo.GPT_NEOX),
+    (4, 50, 32, 256, None, None, RopeAlgo.GPT_J),
+    (
+        2, 20, 16, 128,
+        (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
+        (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
+        RopeAlgo.GPT_NEOX,
+    ),
+    (
+        2, 20, 16, 128,
+        (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
+        (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
+        RopeAlgo.GPT_J,
+    ),
+    (
+        4, 50, 32, 256,
+        (50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1),
+        (50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1),
+        RopeAlgo.GPT_NEOX,
+    ),
+    (
+        4, 50, 32, 256,
+        (50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1),
+        (50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1),
+        RopeAlgo.GPT_J,
+    ),
+    (
+        32, 64, 8, 128,
+        (64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
+        (64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
+        RopeAlgo.GPT_NEOX,
+    ),
+    (
+        32, 64, 8, 128,
+        (64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
+        (64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
+        RopeAlgo.GPT_J,
+    ),
+    (
+        64, 128, 32, 64,
+        (128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1),
+        (128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1),
+        RopeAlgo.GPT_NEOX,
+    ),
+    (
+        64, 128, 32, 64,
+        (128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1),
+        (128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1),
+        RopeAlgo.GPT_J,
+    ),
 ]
 # Tolerance configuration
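The strided cases ask for layouts whose per-dimension strides are inflated relative to contiguous ones: for shape (2, 20, 16, 128) a contiguous layout would be (40960, 2048, 128, 1), while the table requests (655360, 8192, 256, 1), i.e. padding factors of 16/4/2 on the outer three dims, and src and dst may even differ. A sketch of how such a tensor can be materialized (torch.empty_strided is used for illustration; the harness's TensorSpec presumably does its own allocation):

    import torch

    shape = (2, 20, 16, 128)
    strides = (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1)  # from the table

    x = torch.empty_strided(shape, strides, dtype=torch.float16)
    print(x.is_contiguous())  # False: every dim but the last is padded

    # Exactly the input the removed Python-side flattening could not take:
    try:
        x.view(2 * 20, 16, 128)
    except RuntimeError as err:
        print("view failed:", err)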
@@ -49,7 +123,8 @@ def parse_test_cases():
     for data in _TEST_CASES_DATA:
         bs, seq_len, num, head_dim = data[0], data[1], data[2], data[3]
-        algo = data[4]
+        src_strides, dst_strides = data[4], data[5]
+        algo = data[6]
         # Determine shapes based on batch dimension
         out_shape = (bs, seq_len, num, head_dim)
@@ -58,15 +133,16 @@ def parse_test_cases():
         cos_table_shape = (seq_len, head_dim // 2)
         # Check if tensors support in-place operations
-        c_supports_inplace = not is_broadcast(out_shape)
+        # x tensor supports in-place if it's not a broadcasted tensor
+        x_supports_inplace = not is_broadcast(src_strides)
         # Generate test cases for all data types
         for dtype in _TENSOR_DTYPES:
             tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
             # Create typed tensor specs
-            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
-            x_spec = TensorSpec.from_tensor(x_shape, None, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, dst_strides, dtype)
+            x_spec = TensorSpec.from_tensor(x_shape, src_strides, dtype)
             sin_table_spec = TensorSpec.from_tensor(sin_table_shape, None, dtype)
             cos_table_spec = TensorSpec.from_tensor(cos_table_shape, None, dtype)
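is_broadcast is a harness helper that is not part of this diff; judging only from its use here (it now receives a stride tuple or None), a plausible reading is that it flags layouts containing a zero stride, since writing in place through an expanded/broadcast view would make several logical elements alias one storage element. A hypothetical stand-in, not the project's implementation:

    # Hypothetical sketch of the harness's is_broadcast helper.
    def is_broadcast(strides):
        if strides is None:  # None = default contiguous layout, never broadcast
            return False
        return any(s == 0 for s in strides)  # stride 0 => aliased elements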
@@ -83,7 +159,7 @@ def parse_test_cases():
         )
         # Test Case 2: In-place with explicit output tensor
-        if c_supports_inplace:
+        if dst_strides is None or not is_broadcast(dst_strides):
             test_cases.append(
                 TestCase(
                     inputs=[x_spec, sin_table_spec, cos_table_spec],
@@ -95,6 +171,19 @@ def parse_test_cases():
             )
+        # Test Case 3: In-place on input tensor (x)
+        if x_supports_inplace:
+            test_cases.append(
+                TestCase(
+                    inputs=[x_spec, sin_table_spec, cos_table_spec],
+                    kwargs={"algo": algo, "out": 0},  # Use index 0 for first input
+                    output_spec=None,
+                    comparison_target=0,  # Compare first input (x tensor)
+                    tolerance=tolerance,
+                    description=f"Rope - INPLACE(x)",
+                )
+            )
     return test_cases
@@ -107,6 +196,13 @@ def rotary_embedding(t, sin, cos, algo, *, out=None):
         return t_out_1, t_out_2
-    ans = t.clone()
+    # If out parameter is provided and it's the same as input t, operate in-place
+    if out is not None:
+        if out.data_ptr() == t.data_ptr():
+            ans = t  # Use the same tensor for in-place operation
+        else:
+            ans = out  # Use provided output tensor
+    else:
+        ans = t.clone()
     dh = t.shape[-1]
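The reference implementation now distinguishes three calling modes: out aliasing t (rotate in place), a distinct out (write there), and no out (clone and return a copy). Comparing data_ptr() is how the aliasing is detected; a small sketch of that check (note it compares storage base addresses, so offset views of the same buffer would not match):

    import torch

    t = torch.randn(2, 20, 16, 128)
    out_alias = t                    # same tensor, same storage
    out_fresh = torch.empty_like(t)  # separate storage

    print(out_alias.data_ptr() == t.data_ptr())  # True  -> operate on t itself
    print(out_fresh.data_ptr() == t.data_ptr())  # False -> write into out_fresh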
@@ -114,8 +210,8 @@ def rotary_embedding(t, sin, cos, algo, *, out=None):
     assert dh % 2 == 0, "Embedding dimension must be even."
     if RopeAlgo.GPT_J == algo:
-        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
-        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+        t_even = t[..., 0::2]  # [bs, seq_len, n_head, dh // 2]
+        t_odd = t[..., 1::2]  # [bs, seq_len, n_head, dh // 2]
         t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)
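Only the shape comments changed in this hunk (the slices now carry the batch dimension), but it marks where the two algorithms diverge: GPT-J rotates interleaved lane pairs (0::2 with 1::2), while the GPT-NeoX branch below pairs the first and second halves of the head dimension. Since _torch_rope itself is outside this diff, here is a minimal self-contained sketch of the conventional pairings; rope_ref and the table construction are illustrative names, not the project's API:

    import torch

    def rope_ref(t, sin, cos, interleaved):
        # t: [bs, seq_len, n_head, dh]; sin/cos already broadcast-shaped,
        # e.g. [1, seq_len, 1, dh // 2].
        if interleaved:  # GPT-J style: pair lanes 0::2 with 1::2
            a, b = t[..., 0::2], t[..., 1::2]
        else:            # GPT-NeoX style: pair first half with second half
            half = t.shape[-1] // 2
            a, b = t[..., :half], t[..., half:]
        ra = a * cos - b * sin  # rotate each (a, b) pair by the position angle
        rb = a * sin + b * cos
        out = torch.empty_like(t)
        if interleaved:
            out[..., 0::2], out[..., 1::2] = ra, rb
        else:
            out[..., :half], out[..., half:] = ra, rb
        return out

    bs, seq_len, n_head, dh = 1, 5, 2, 8
    t = torch.randn(bs, seq_len, n_head, dh)
    pos = torch.arange(seq_len, dtype=torch.float32)
    freq = 1.0 / (10000.0 ** (torch.arange(0, dh, 2, dtype=torch.float32) / dh))
    angles = torch.outer(pos, freq)       # [seq_len, dh // 2]
    sin = angles.sin()[None, :, None, :]  # broadcast over batch and heads
    cos = angles.cos()[None, :, None, :]
    y_gptj = rope_ref(t, sin, cos, interleaved=True)
    y_neox = rope_ref(t, sin, cos, interleaved=False)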
@@ -131,9 +227,10 @@ def rotary_embedding(t, sin, cos, algo, *, out=None):
         ans[..., :half_dim] = t_out_first.to(dt)
         ans[..., half_dim:] = t_out_second.to(dt)
     else:
-        raise KeyError("error Algo")
+        raise KeyError("Unsupported RoPE algorithm")
-    if out is not None:
+    # If operating in-place on t, we don't need to copy back
+    if out is not None and out.data_ptr() != t.data_ptr():
         out.copy_(ans)
         return out
     return ans