Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ee68b55d
Commit
ee68b55d
authored
Mar 26, 2025
by
Zimin Li
Browse files
issue/46: Move tensor broadcast check functions to infiniopTensorDescriptor_t
parent
baafb916
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
26 deletions
+32
-26
src/infiniop/binary/binary.h
src/infiniop/binary/binary.h
+9
-26
src/infiniop/tensor.h
src/infiniop/tensor.h
+4
-0
src/infiniop/tensor_descriptor.cc
src/infiniop/tensor_descriptor.cc
+19
-0
No files found.
src/infiniop/binary/binary.h
View file @
ee68b55d
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "../operator.h"
#include "../operator.h"
#include "../tensor.h"
#include "../tensor.h"
#include <algorithm>
#include <numeric>
#include <numeric>
/**
/**
...
@@ -73,39 +72,23 @@ inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
...
@@ -73,39 +72,23 @@ inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
return
INFINI_STATUS_BAD_PARAM
;
return
INFINI_STATUS_BAD_PARAM
;
}
}
const
auto
&
c_shape
=
c_desc
->
shape
();
info
.
c_data_size
=
c_desc
->
numel
();
const
auto
&
a_shape
=
a_desc
->
shape
();
const
auto
&
b_shape
=
b_desc
->
shape
();
const
auto
&
c_strides
=
c_desc
->
strides
();
const
auto
&
a_strides
=
a_desc
->
strides
();
const
auto
&
b_strides
=
b_desc
->
strides
();
info
.
c_data_size
=
std
::
accumulate
(
c_shape
.
begin
(),
c_shape
.
end
(),
size_t
(
1
),
std
::
multiplies
<
size_t
>
());
info
.
ndim
=
c_desc
->
ndim
();
info
.
ndim
=
c_desc
->
ndim
();
info
.
contiguous
=
c_desc
->
isContiguous
()
&&
a_desc
->
isContiguous
()
&&
b_desc
->
isContiguous
();
info
.
contiguous
=
c_desc
->
isContiguous
()
&&
a_desc
->
isContiguous
()
&&
b_desc
->
isContiguous
();
// Check if a tensor is broadcasted by checking its shape and strides
auto
isBroadcasted
=
[](
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
vector
<
ptrdiff_t
>
&
strides
)
{
return
std
::
any_of
(
shape
.
begin
(),
shape
.
end
(),
[
&
,
i
=
0
](
const
auto
&
)
mutable
{
return
shape
[
i
]
!=
1
&&
strides
[
i
++
]
==
0
;
});
};
// Destination cannot have broadcast setup
// Destination cannot have broadcast setup
if
(
isBroadcasted
(
c_shape
,
c_strides
))
{
if
(
c_desc
->
hasBroadcastDim
(
))
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
}
const
bool
ndim_match
=
(
c_desc
->
ndim
()
==
a_desc
->
ndim
())
&&
(
c_desc
->
ndim
()
==
b_desc
->
ndim
());
const
bool
ndim_match
=
(
c_desc
->
ndim
()
==
a_desc
->
ndim
())
&&
(
c_desc
->
ndim
()
==
b_desc
->
ndim
());
info
.
broadcasted
=
!
info
.
contiguous
&&
(
!
ndim_match
||
isBroadcasted
(
a_shape
,
a_strides
)
||
isBroadcasted
(
b_shape
,
b_strides
));
info
.
broadcasted
=
!
info
.
contiguous
&&
(
!
ndim_match
||
a_desc
->
hasBroadcastDim
()
||
b_desc
->
hasBroadcastDim
(
));
info
.
c_shape
=
std
::
move
(
c_shape
);
info
.
c_shape
=
std
::
move
(
c_
desc
->
shape
()
);
info
.
a_shape
=
std
::
move
(
a_shape
);
info
.
a_shape
=
std
::
move
(
a_
desc
->
shape
()
);
info
.
b_shape
=
std
::
move
(
b_shape
);
info
.
b_shape
=
std
::
move
(
b_
desc
->
shape
()
);
info
.
c_strides
=
std
::
move
(
c_strides
);
info
.
c_strides
=
std
::
move
(
c_
desc
->
strides
()
);
info
.
a_strides
=
std
::
move
(
a_strides
);
info
.
a_strides
=
std
::
move
(
a_
desc
->
strides
()
);
info
.
b_strides
=
std
::
move
(
b_strides
);
info
.
b_strides
=
std
::
move
(
b_
desc
->
strides
()
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
...
src/infiniop/tensor.h
View file @
ee68b55d
...
@@ -28,6 +28,10 @@ public:
...
@@ -28,6 +28,10 @@ public:
bool
isContiguous
()
const
;
bool
isContiguous
()
const
;
size_t
numel
()
const
;
size_t
numel
()
const
;
// a dim is broadcasted if it's corresponding stride is 0 but dim > 1
bool
hasBroadcastDim
()
const
;
std
::
vector
<
size_t
>
getBroadcastDim
()
const
;
infiniopTensorDescriptor_t
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
;
infiniopTensorDescriptor_t
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
;
infiniopTensorDescriptor_t
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
;
infiniopTensorDescriptor_t
dimSplit
(
size_t
axis
,
const
std
::
vector
<
size_t
>
&
dims
)
const
;
infiniopTensorDescriptor_t
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
;
infiniopTensorDescriptor_t
dimPermute
(
const
std
::
vector
<
size_t
>
&
order
)
const
;
...
...
src/infiniop/tensor_descriptor.cc
View file @
ee68b55d
#include "../utils.h"
#include "../utils.h"
#include "tensor.h"
#include "tensor.h"
#include <algorithm>
#include <cstring>
#include <cstring>
#include <functional>
#include <functional>
#include <numeric>
#include <numeric>
...
@@ -85,6 +86,24 @@ bool InfiniopTensorDescriptor::isContiguous() const {
...
@@ -85,6 +86,24 @@ bool InfiniopTensorDescriptor::isContiguous() const {
return
isContiguous
(
0
,
ndim
()
-
1
);
return
isContiguous
(
0
,
ndim
()
-
1
);
}
}
bool
InfiniopTensorDescriptor
::
hasBroadcastDim
()
const
{
return
std
::
any_of
(
_shape
.
begin
(),
_shape
.
end
(),
[
&
,
i
=
0
](
const
auto
&
)
mutable
{
return
_shape
[
i
]
!=
1
&&
_strides
[
i
++
]
==
0
;
});
}
std
::
vector
<
size_t
>
InfiniopTensorDescriptor
::
getBroadcastDim
()
const
{
std
::
vector
<
size_t
>
res
;
for
(
size_t
i
=
0
;
i
<
ndim
();
++
i
)
{
if
(
_shape
[
i
]
!=
1
&&
_strides
[
i
]
==
0
)
{
res
.
push_back
(
i
);
}
}
return
res
;
}
infiniopTensorDescriptor_t
InfiniopTensorDescriptor
::
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
{
infiniopTensorDescriptor_t
InfiniopTensorDescriptor
::
dimMerge
(
size_t
dim_start
,
size_t
dim_end
)
const
{
if
(
dim_start
>
dim_end
||
dim_end
>=
ndim
())
{
if
(
dim_start
>
dim_end
||
dim_end
>=
ndim
())
{
return
nullptr
;
return
nullptr
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment