Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
baafb916
Commit
baafb916
authored
Mar 26, 2025
by
Zimin Li
Browse files
issue/46: change BinaryInfo create method, update binary contiguous check
parent
377f6e20
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
42 deletions
+57
-42
src/infiniop/binary/binary.h
src/infiniop/binary/binary.h
+47
-28
src/infiniop/binary/cpu/binary_cpu.h
src/infiniop/binary/cpu/binary_cpu.h
+7
-10
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
+3
-4
No files found.
src/infiniop/binary/binary.h
View file @
baafb916
#ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_H__
#include "../devices/cpu/common_cpu.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <numeric>
/**
...
...
@@ -52,24 +52,9 @@ namespace op::binary {
// Stores metadata for binary operations on CPU
struct
BinaryInfo
{
private:
BinaryInfo
(
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
:
ndim
(
c_desc
->
ndim
()),
c_shape
(
std
::
move
(
c_desc
->
shape
())),
a_shape
(
std
::
move
(
a_desc
->
shape
())),
b_shape
(
std
::
move
(
b_desc
->
shape
())),
c_strides
(
std
::
move
(
c_desc
->
strides
())),
a_strides
(
std
::
move
(
a_desc
->
strides
())),
b_strides
(
std
::
move
(
b_desc
->
strides
()))
{
this
->
c_data_size
=
std
::
accumulate
(
c_shape
.
begin
(),
c_shape
.
end
(),
size_t
(
1
),
std
::
multiplies
<
size_t
>
());
this
->
broadcasted
=
(
a_strides
!=
c_strides
)
||
(
b_strides
!=
c_strides
);
}
public:
size_t
c_data_size
;
size_t
ndim
;
bool
contiguous
;
bool
broadcasted
;
std
::
vector
<
size_t
>
c_shape
;
std
::
vector
<
size_t
>
a_shape
;
...
...
@@ -77,20 +62,54 @@ public:
std
::
vector
<
ptrdiff_t
>
c_strides
;
std
::
vector
<
ptrdiff_t
>
a_strides
;
std
::
vector
<
ptrdiff_t
>
b_strides
;
};
static
infiniStatus_t
create
(
BinaryInfo
**
instance
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
if
(
!
c_desc
||
!
a_desc
||
!
b_desc
)
{
return
INFINI_STATUS_BAD_PARAM
;
}
inline
infiniStatus_t
createBinaryInfo
(
BinaryInfo
&
info
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
*
instance
=
new
BinaryInfo
(
c_desc
,
a_desc
,
b_desc
)
;
return
INFINI_STATUS_
SUCCESS
;
if
(
!
c_desc
||
!
a_desc
||
!
b_desc
)
{
return
INFINI_STATUS_
BAD_PARAM
;
}
};
const
auto
&
c_shape
=
c_desc
->
shape
();
const
auto
&
a_shape
=
a_desc
->
shape
();
const
auto
&
b_shape
=
b_desc
->
shape
();
const
auto
&
c_strides
=
c_desc
->
strides
();
const
auto
&
a_strides
=
a_desc
->
strides
();
const
auto
&
b_strides
=
b_desc
->
strides
();
info
.
c_data_size
=
std
::
accumulate
(
c_shape
.
begin
(),
c_shape
.
end
(),
size_t
(
1
),
std
::
multiplies
<
size_t
>
());
info
.
ndim
=
c_desc
->
ndim
();
info
.
contiguous
=
c_desc
->
isContiguous
()
&&
a_desc
->
isContiguous
()
&&
b_desc
->
isContiguous
();
// Check if a tensor is broadcasted by checking its shape and strides.
//
// Returns true when at least one dimension has an extent other than 1 and a
// stride of 0. Dimensions of extent 1 are ignored, whatever their stride.
// Only the dimensions present in both `shape` and `strides` are inspected,
// so empty inputs yield false.
auto isBroadcasted = [](const std::vector<size_t> &shape, const std::vector<ptrdiff_t> &strides) {
    // An explicit index is used on purpose: advancing a counter on the
    // right-hand side of `&&` would be skipped whenever the left-hand side is
    // false (extent == 1), leaving the index stuck on that dimension and
    // hiding any broadcast dimension that follows it.
    for (size_t dim = 0; dim < shape.size() && dim < strides.size(); ++dim) {
        if (shape[dim] != 1 && strides[dim] == 0) {
            return true;
        }
    }
    return false;
};
// Destination cannot have broadcast setup
if
(
isBroadcasted
(
c_shape
,
c_strides
))
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
const
bool
ndim_match
=
(
c_desc
->
ndim
()
==
a_desc
->
ndim
())
&&
(
c_desc
->
ndim
()
==
b_desc
->
ndim
());
info
.
broadcasted
=
!
info
.
contiguous
&&
(
!
ndim_match
||
isBroadcasted
(
a_shape
,
a_strides
)
||
isBroadcasted
(
b_shape
,
b_strides
));
info
.
c_shape
=
std
::
move
(
c_shape
);
info
.
a_shape
=
std
::
move
(
a_shape
);
info
.
b_shape
=
std
::
move
(
b_shape
);
info
.
c_strides
=
std
::
move
(
c_strides
);
info
.
a_strides
=
std
::
move
(
a_strides
);
info
.
b_strides
=
std
::
move
(
b_strides
);
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace op::binary
#endif // __INFINIOP_BINARY_H__
src/infiniop/binary/cpu/binary_cpu.h
View file @
baafb916
#ifndef __INFINIOP_BINARY_CPU_H__
#define __INFINIOP_BINARY_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../binary.h"
#include <utility>
...
...
@@ -18,11 +19,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
data_size
;
++
i
)
{
size_t
a_index
=
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
a_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
a_shape
.
data
(),
info
.
a_strides
.
data
());
size_t
b_index
=
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
b_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
b_shape
.
data
(),
info
.
b_strides
.
data
());
size_t
c_index
=
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
c_shape
.
data
(),
info
.
c_strides
.
data
());
size_t
a_index
=
info
.
contiguous
?
i
:
(
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
a_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
a_shape
.
data
(),
info
.
a_strides
.
data
()));
size_t
b_index
=
info
.
contiguous
?
i
:
(
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
b_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
b_shape
.
data
(),
info
.
b_strides
.
data
()));
size_t
c_index
=
info
.
contiguous
?
i
:
(
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
c_shape
.
data
(),
info
.
c_strides
.
data
()));
c_
[
c_index
]
=
BinaryOp
{}(
a_
[
a_index
],
b_
[
b_index
],
std
::
forward
<
Args
>
(
args
)...);
}
...
...
@@ -38,11 +37,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
data_size
;
++
i
)
{
size_t
a_index
=
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
a_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
a_shape
.
data
(),
info
.
a_strides
.
data
());
size_t
b_index
=
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
b_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
b_shape
.
data
(),
info
.
b_strides
.
data
());
size_t
c_index
=
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
c_shape
.
data
(),
info
.
c_strides
.
data
());
size_t
a_index
=
info
.
contiguous
?
i
:
(
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
a_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
a_shape
.
data
(),
info
.
a_strides
.
data
()));
size_t
b_index
=
info
.
contiguous
?
i
:
(
info
.
broadcasted
?
op
::
common_cpu
::
indexToReducedOffset
(
i
,
info
.
ndim
,
info
.
c_strides
.
data
(),
info
.
b_strides
.
data
())
:
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
b_shape
.
data
(),
info
.
b_strides
.
data
()));
size_t
c_index
=
info
.
contiguous
?
i
:
(
op
::
common_cpu
::
indexToOffset
(
i
,
info
.
ndim
,
info
.
c_shape
.
data
(),
info
.
c_strides
.
data
()));
if
constexpr
(
std
::
is_same_v
<
Tdata
,
fp16_t
>
)
{
float
a_val
=
utils
::
cast
<
float
>
(
a_
[
a_index
]);
...
...
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
View file @
baafb916
...
...
@@ -22,18 +22,17 @@ infiniStatus_t Descriptor::create(
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
op
::
binary
::
BinaryInfo
*
info
=
nullptr
;
CHECK_STATUS
(
op
::
binary
::
BinaryInfo
::
create
(
&
info
,
out_desc
,
up_desc
,
gate_desc
));
op
::
binary
::
BinaryInfo
info
;
CHECK_STATUS
(
op
::
binary
::
create
BinaryInfo
(
info
,
out_desc
,
up_desc
,
gate_desc
));
// Create descriptor
*
desc_ptr
=
new
Descriptor
(
dtype
,
*
info
,
std
::
move
(
info
)
,
nullptr
,
handle
->
device
,
handle
->
device_id
);
delete
info
;
return
INFINI_STATUS_SUCCESS
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment