Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
e1caa4f5
Commit
e1caa4f5
authored
Mar 04, 2025
by
PanZezhong
Browse files
issue/78 修改rearrange
parent
12046d02
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
src/utils/rearrange.cc
src/utils/rearrange.cc
+5
-3
No files found.
src/utils/rearrange.cc
View file @
e1caa4f5
...
@@ -15,6 +15,9 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
...
@@ -15,6 +15,9 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
const
ptrdiff_t
*
src_strides_
,
const
ptrdiff_t
*
src_strides_
,
size_t
ndim
,
size_t
ndim
,
size_t
element_size
)
{
size_t
element_size
)
{
ptrdiff_t
unit
=
element_size
;
struct
Dim
{
struct
Dim
{
size_t
len
;
size_t
len
;
ptrdiff_t
dst
,
src
;
ptrdiff_t
dst
,
src
;
...
@@ -24,7 +27,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
...
@@ -24,7 +27,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
// 剔除初始的 1 长维度
// 剔除初始的 1 长维度
if
(
shape
[
i
]
!=
1
)
{
if
(
shape
[
i
]
!=
1
)
{
auto
sd
=
dst_strides_
[
i
],
ss
=
src_strides_
[
i
];
auto
sd
=
dst_strides_
[
i
]
*
unit
,
ss
=
src_strides_
[
i
]
*
unit
;
// assert (sd != 0)
// assert (sd != 0)
dims
.
push_back
(
Dim
{
shape
[
i
],
sd
,
ss
});
dims
.
push_back
(
Dim
{
shape
[
i
],
sd
,
ss
});
}
}
...
@@ -40,7 +43,6 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
...
@@ -40,7 +43,6 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
return
std
::
abs
(
a
.
dst
)
>
std
::
abs
(
b
.
dst
);
return
std
::
abs
(
a
.
dst
)
>
std
::
abs
(
b
.
dst
);
});
});
// # 合并连续维度
// # 合并连续维度
ptrdiff_t
unit
=
element_size
;
// ## 合并末尾连续维度到 unit
// ## 合并末尾连续维度到 unit
for
(
auto
it
=
dims
.
rbegin
();
it
!=
dims
.
rend
();
++
it
)
{
for
(
auto
it
=
dims
.
rbegin
();
it
!=
dims
.
rend
();
++
it
)
{
if
(
it
->
dst
==
unit
&&
it
->
src
==
unit
)
{
if
(
it
->
dst
==
unit
&&
it
->
src
==
unit
)
{
...
@@ -51,7 +53,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
...
@@ -51,7 +53,7 @@ std::optional<RearrangeMeta> RearrangeMeta::create(
}
}
}
}
// ## 合并任意连续维度
// ## 合并任意连续维度
for
(
size
_t
i
=
ndim
-
1
;
i
>
0
;
--
i
)
{
for
(
ptrdiff
_t
i
=
ndim
-
1
;
i
>
0
;
--
i
)
{
auto
&
f
=
dims
[
i
-
1
];
auto
&
f
=
dims
[
i
-
1
];
auto
&
b
=
dims
[
i
];
auto
&
b
=
dims
[
i
];
ptrdiff_t
len
=
b
.
len
;
ptrdiff_t
len
=
b
.
len
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment