Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ba9b0d83
Commit
ba9b0d83
authored
Mar 04, 2025
by
YdrMaster
Browse files
issue/78/feat: 分离 rearrange 的规划和执行以供算子复用
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
3aad9f8f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
111 deletions
+73
-111
src/utils/rearrange.cc
src/utils/rearrange.cc
+48
-111
src/utils/rearrange.h
src/utils/rearrange.h
+25
-0
No files found.
src/utils/rearrange.cc
View file @
ba9b0d83
...
...
@@ -6,15 +6,15 @@
namespace
utils
{
void
rearrange
(
void
*
dst_
,
const
void
*
src_
,
RearrangeMeta
::
RearrangeMeta
(
std
::
vector
<
ptrdiff_t
>
meta
)
:
_meta
(
std
::
move
(
meta
))
{}
std
::
optional
<
RearrangeMeta
>
RearrangeMeta
::
create
(
const
size_t
*
shape
,
const
ptrdiff_t
*
dst_strides_
,
const
ptrdiff_t
*
src_strides_
,
size_t
ndim
,
size_t
element_size
)
{
struct
Dim
{
size_t
len
;
ptrdiff_t
dst
,
src
;
...
...
@@ -63,131 +63,68 @@ void rearrange(
}
dims
.
resize
(
ndim
);
// 填写序号步长、输入步长和输出步长
std
::
vector
<
ptrdiff_t
>
idx_strides
(
ndim
+
1
),
dst_strides
(
ndim
),
src_strides
(
ndim
);
idx_strides
[
ndim
]
=
1
;
std
::
vector
<
ptrdiff_t
>
meta
(
2
+
ndim
*
3
);
meta
[
0
]
=
unit
;
meta
[
1
+
ndim
]
=
1
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
idx_strides
[
i
]
=
dims
[
i
].
len
;
dst_strides
[
i
]
=
dims
[
i
].
dst
;
src_strides
[
i
]
=
dims
[
i
].
src
;
meta
[
1
+
i
]
=
dims
[
i
].
len
;
meta
[
1
+
1
+
ndim
+
i
]
=
dims
[
i
].
dst
;
meta
[
1
+
1
+
ndim
*
2
+
i
]
=
dims
[
i
].
src
;
}
for
(
size
_t
i
=
ndim
;
i
>
0
;
--
i
)
{
idx_strides
[
i
-
1
]
*=
idx_strides
[
i
];
for
(
ptrdiff
_t
i
=
ndim
;
i
>
0
;
--
i
)
{
meta
[
1
+
i
-
1
]
*=
meta
[
1
+
i
];
}
return
{
RearrangeMeta
(
std
::
move
(
meta
))};
}
size_t
RearrangeMeta
::
ndim
()
const
{
return
(
_meta
.
size
()
-
2
)
/
3
;
}
size_t
RearrangeMeta
::
unit
()
const
{
return
_meta
[
0
];
}
size_t
RearrangeMeta
::
count
()
const
{
return
_meta
[
1
];
}
const
ptrdiff_t
*
RearrangeMeta
::
idx_strides
()
const
{
return
_meta
.
data
()
+
2
;
}
const
ptrdiff_t
*
RearrangeMeta
::
dst_strides
()
const
{
return
idx_strides
()
+
ndim
();
}
const
ptrdiff_t
*
RearrangeMeta
::
src_strides
()
const
{
return
dst_strides
()
+
ndim
();
}
void
RearrangeMeta
::
launch
(
void
*
dst_
,
const
void
*
src_
)
const
{
auto
const
ndim_
=
ndim
();
auto
const
count_
=
count
();
auto
const
unit_
=
unit
();
auto
const
idx_strides_
=
idx_strides
();
auto
const
dst_strides_
=
dst_strides
();
auto
const
src_strides_
=
src_strides
();
// 执行 rearrange
if
(
idx_strides
[
0
]
==
1
)
{
std
::
memcpy
(
dst_
,
src_
,
unit
);
if
(
count_
==
1
)
{
std
::
memcpy
(
dst_
,
src_
,
unit
_
);
}
else
{
for
(
size_t
i
=
0
;
i
<
idx_strides
[
0
];
++
i
)
{
for
(
size_t
i
=
0
;
i
<
idx_strides
_
[
0
];
++
i
)
{
auto
dst
=
reinterpret_cast
<
char
*>
(
dst_
);
auto
src
=
reinterpret_cast
<
const
char
*>
(
src_
);
for
(
size_t
j
=
0
;
j
<
ndim
;
++
j
)
{
auto
k
=
i
/
idx_strides
[
j
+
1
];
dst
+=
k
*
dst_strides
[
j
];
src
+=
k
*
src_strides
[
j
];
i
%=
idx_strides
[
j
+
1
];
auto
rem
=
i
;
for
(
size_t
j
=
0
;
j
<
ndim_
;
++
j
)
{
auto
k
=
rem
/
idx_strides_
[
j
+
1
];
dst
+=
k
*
dst_strides_
[
j
];
src
+=
k
*
src_strides_
[
j
];
rem
%=
idx_strides_
[
j
+
1
];
}
std
::
memcpy
(
dst
,
src
,
unit
);
std
::
memcpy
(
dst
,
src
,
unit
_
);
}
}
}
}
// namespace utils
#include "check.h"
#include <algorithm>
#include <cstring>
#include <vector>
namespace
utils
{
void
rearrange
(
void
*
dst
_
,
const
void
*
src
_
,
void
*
dst
,
const
void
*
src
,
const
size_t
*
shape
,
const
ptrdiff_t
*
dst_strides
_
,
const
ptrdiff_t
*
src_strides
_
,
const
ptrdiff_t
*
dst_strides
,
const
ptrdiff_t
*
src_strides
,
size_t
ndim
,
size_t
element_size
)
{
struct
Dim
{
size_t
len
;
ptrdiff_t
dst
,
src
;
};
ptrdiff_t
unit
=
element_size
;
std
::
vector
<
Dim
>
dims
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
// 剔除初始的 1 长维度
if
(
shape
[
i
]
!=
1
)
{
auto
sd
=
dst_strides_
[
i
]
*
unit
,
ss
=
src_strides_
[
i
]
*
unit
;
// assert (sd != 0)
dims
.
push_back
(
Dim
{
shape
[
i
],
sd
,
ss
});
}
}
// 排序
std
::
sort
(
dims
.
begin
(),
dims
.
end
(),
[](
const
Dim
&
a
,
const
Dim
&
b
)
{
if
(
std
::
abs
(
a
.
dst
)
==
std
::
abs
(
b
.
dst
))
{
if
(
std
::
abs
(
a
.
src
)
==
std
::
abs
(
b
.
src
))
{
return
a
.
len
<
b
.
len
;
}
return
std
::
abs
(
a
.
src
)
>
std
::
abs
(
b
.
src
);
}
return
std
::
abs
(
a
.
dst
)
>
std
::
abs
(
b
.
dst
);
});
// # 合并连续维度
// ## 合并末尾连续维度到 unit
for
(
auto
it
=
dims
.
rbegin
();
it
!=
dims
.
rend
();
++
it
)
{
if
(
it
->
dst
==
unit
&&
it
->
src
==
unit
)
{
unit
*=
it
->
len
;
ndim
-=
1
;
auto
scheme
=
RearrangeMeta
::
create
(
shape
,
dst_strides
,
src_strides
,
ndim
,
element_size
);
if
(
scheme
)
{
scheme
->
launch
(
dst
,
src
);
}
else
{
break
;
}
}
// ## 合并任意连续维度
for
(
ptrdiff_t
i
=
ndim
-
1
;
i
>
0
;
--
i
)
{
auto
&
f
=
dims
[
i
-
1
];
auto
&
b
=
dims
[
i
];
ptrdiff_t
len
=
b
.
len
;
if
(
b
.
dst
*
len
==
f
.
dst
&&
b
.
src
*
len
==
f
.
src
)
{
f
=
Dim
{
b
.
len
*
f
.
len
,
b
.
dst
,
b
.
src
};
b
=
Dim
{
1
,
0
,
0
};
ndim
-=
1
;
}
}
dims
.
resize
(
ndim
);
// 填写序号步长、输入步长和输出步长
std
::
vector
<
ptrdiff_t
>
idx_strides
(
ndim
+
1
),
dst_strides
(
ndim
),
src_strides
(
ndim
);
idx_strides
[
ndim
]
=
1
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
idx_strides
[
i
]
=
dims
[
i
].
len
;
dst_strides
[
i
]
=
dims
[
i
].
dst
;
src_strides
[
i
]
=
dims
[
i
].
src
;
}
for
(
ptrdiff_t
i
=
ndim
;
i
>
0
;
--
i
)
{
idx_strides
[
i
-
1
]
*=
idx_strides
[
i
];
}
// 执行 rearrange
if
(
idx_strides
[
0
]
==
1
)
{
std
::
memcpy
(
dst_
,
src_
,
unit
);
}
else
{
for
(
size_t
i
=
0
;
i
<
idx_strides
[
0
];
++
i
)
{
auto
dst
=
reinterpret_cast
<
char
*>
(
dst_
);
auto
src
=
reinterpret_cast
<
const
char
*>
(
src_
);
auto
rem
=
i
;
for
(
size_t
j
=
0
;
j
<
ndim
;
++
j
)
{
auto
k
=
rem
/
idx_strides
[
j
+
1
];
dst
+=
k
*
dst_strides
[
j
];
src
+=
k
*
src_strides
[
j
];
rem
%=
idx_strides
[
j
+
1
];
}
std
::
memcpy
(
dst
,
src
,
unit
);
}
std
::
abort
();
}
}
...
...
src/utils/rearrange.h
View file @
ba9b0d83
#ifndef __INFINIUTILS_REARRANGE_H__
#define __INFINIUTILS_REARRANGE_H__
#include <optional>
#include <stddef.h>
#include <vector>
namespace
utils
{
class
RearrangeMeta
{
std
::
vector
<
ptrdiff_t
>
_meta
;
RearrangeMeta
(
std
::
vector
<
ptrdiff_t
>
);
public:
static
std
::
optional
<
RearrangeMeta
>
create
(
const
size_t
*
shape
,
const
ptrdiff_t
*
dst_strides
,
const
ptrdiff_t
*
src_strides
,
size_t
ndim
,
size_t
element_size
);
size_t
ndim
()
const
;
size_t
unit
()
const
;
size_t
count
()
const
;
const
ptrdiff_t
*
idx_strides
()
const
;
const
ptrdiff_t
*
dst_strides
()
const
;
const
ptrdiff_t
*
src_strides
()
const
;
void
launch
(
void
*
dst
,
const
void
*
src
)
const
;
};
void
rearrange
(
void
*
dst
,
const
void
*
src
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment