Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ebff82e3
Commit
ebff82e3
authored
Mar 03, 2025
by
YdrMaster
Browse files
issue/78/feat: 实现 rearrange
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
ecb6793b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
10 deletions
+112
-10
src/utils/rearrange.cc
src/utils/rearrange.cc
+96
-0
src/utils/rearrange.h
src/utils/rearrange.h
+16
-10
No files found.
src/utils/rearrange.cc
View file @
ebff82e3
#include "rearrange.h"
#include "rearrange.h"
#include "check.h"
#include <algorithm>
#include <cstring>
#include <vector>
namespace
utils
{
void
rearrange
(
void
*
dst_
,
const
void
*
src_
,
const
size_t
*
shape
,
const
ptrdiff_t
*
dst_strides_
,
const
ptrdiff_t
*
src_strides_
,
size_t
ndim
,
size_t
element_size
)
{
struct
Dim
{
size_t
len
;
ptrdiff_t
dst
,
src
;
};
std
::
vector
<
Dim
>
dims
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
// 剔除初始的 1 长维度
if
(
shape
[
i
]
!=
1
)
{
auto
sd
=
dst_strides_
[
i
],
ss
=
src_strides_
[
i
];
// assert (sd != 0)
dims
.
push_back
(
Dim
{
shape
[
i
],
sd
,
ss
});
}
}
// 排序
std
::
sort
(
dims
.
begin
(),
dims
.
end
(),
[](
const
Dim
&
a
,
const
Dim
&
b
)
{
if
(
std
::
abs
(
a
.
dst
)
==
std
::
abs
(
b
.
dst
))
{
if
(
std
::
abs
(
a
.
src
)
==
std
::
abs
(
b
.
src
))
{
return
a
.
len
<
b
.
len
;
}
return
std
::
abs
(
a
.
src
)
>
std
::
abs
(
b
.
src
);
}
return
std
::
abs
(
a
.
dst
)
>
std
::
abs
(
b
.
dst
);
});
// # 合并连续维度
ptrdiff_t
unit
=
element_size
;
// ## 合并末尾连续维度到 unit
for
(
auto
it
=
dims
.
rbegin
();
it
!=
dims
.
rend
();
++
it
)
{
if
(
it
->
dst
==
unit
&&
it
->
src
==
unit
)
{
unit
*=
it
->
len
;
ndim
-=
1
;
}
else
{
break
;
}
}
// ## 合并任意连续维度
for
(
size_t
i
=
ndim
-
1
;
i
>
0
;
--
i
)
{
auto
&
f
=
dims
[
i
-
1
];
auto
&
b
=
dims
[
i
];
ptrdiff_t
len
=
b
.
len
;
if
(
b
.
dst
*
len
==
f
.
dst
&&
b
.
src
*
len
==
f
.
src
)
{
f
=
Dim
{
b
.
len
*
f
.
len
,
b
.
dst
,
b
.
src
};
b
=
Dim
{
1
,
0
,
0
};
ndim
-=
1
;
}
}
dims
.
resize
(
ndim
);
// 填写序号步长、输入步长和输出步长
std
::
vector
<
ptrdiff_t
>
idx_strides
(
ndim
+
1
),
dst_strides
(
ndim
),
src_strides
(
ndim
);
idx_strides
[
ndim
]
=
1
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
idx_strides
[
i
]
=
dims
[
i
].
len
;
dst_strides
[
i
]
=
dims
[
i
].
dst
;
src_strides
[
i
]
=
dims
[
i
].
src
;
}
for
(
size_t
i
=
ndim
;
i
>
0
;
--
i
)
{
idx_strides
[
i
-
1
]
*=
idx_strides
[
i
];
}
// 执行 rearrange
if
(
idx_strides
[
0
]
==
1
)
{
std
::
memcpy
(
dst_
,
src_
,
unit
);
}
else
{
for
(
size_t
i
=
0
;
i
<
idx_strides
[
0
];
++
i
)
{
auto
dst
=
reinterpret_cast
<
char
*>
(
dst_
);
auto
src
=
reinterpret_cast
<
const
char
*>
(
src_
);
for
(
size_t
j
=
0
;
j
<
ndim
;
++
j
)
{
auto
k
=
i
/
idx_strides
[
j
+
1
];
dst
+=
k
*
dst_strides
[
j
];
src
+=
k
*
src_strides
[
j
];
i
%=
idx_strides
[
j
+
1
];
}
std
::
memcpy
(
dst
,
src
,
unit
);
}
}
}
}
// namespace utils
src/utils/rearrange.h
View file @
ebff82e3
#ifndef INFINIUTILS_REARRANGE_H
#ifndef __INFINIUTILS_REARRANGE_H__
#define INFINIUTILS_REARRANGE_H
#define __INFINIUTILS_REARRANGE_H__
#include <stddef.h>
#include <stddef.h>
void
rearrange
(
void
*
dst
,
namespace
utils
{
const
void
*
src
,
const
size_t
*
shape
,
void
rearrange
(
const
ptrdiff_t
*
dst_strides
,
void
*
dst
,
const
ptrdiff_t
*
src_strides
,
const
void
*
src
,
const
size_t
ndim
,
const
size_t
*
shape
,
size_t
element_size
);
const
ptrdiff_t
*
dst_strides
,
const
ptrdiff_t
*
src_strides
,
size_t
ndim
,
size_t
element_size
);
}
// namespace utils
#endif // INFINIUTILS_REARRANGE_H
#endif //
__
INFINIUTILS_REARRANGE_H
__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment