jerrrrry / infinicore / Commits / c2e87202

Commit c2e87202, authored Jun 04, 2025 by Catheriany

    Merge remote-tracking branch 'origin/main' into issue/142

Parents: 41818f84, c203635b

Changes: 175 — showing 20 changed files with 2020 additions and 2 deletions (+2020 / -2)
src/infiniop/devices/kunlun/kunlun_kernel_dtype.h            +22   -0
src/infiniop/devices/maca/common_maca.h                      +15   -0
src/infiniop/devices/maca/maca_handle.cc                     +23   -1
src/infiniop/devices/musa/common_musa.h                      +15   -0
src/infiniop/devices/musa/musa_handle.cc                     +23   -1
src/infiniop/elementwise/cpu/elementwise_cpu.h               +201  -0
src/infiniop/elementwise/cuda/elementwise_cuda.cuh           +419  -0
src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh       +109  -0
src/infiniop/elementwise/elementwise.h                       +206  -0
src/infiniop/elementwise/kunlun/elementwise_kunlun.h         +137  -0
src/infiniop/elementwise/kunlun/elementwise_kunlun_api.h     +50   -0
src/infiniop/elementwise/kunlun/elementwise_kunlun_kernel.h  +192  -0
src/infiniop/ops/add/cpu/add_cpu.cc                          +52   -0
src/infiniop/ops/add/cpu/add_cpu.h                           +19   -0
src/infiniop/ops/add/cuda/add_cuda.cu                        +57   -0
src/infiniop/ops/add/cuda/add_cuda.cuh                       +8    -0
src/infiniop/ops/add/cuda/add_cuda_internal.cuh              +26   -0
src/infiniop/ops/add/operator.cc                             +118  -0
src/infiniop/ops/attention/attention.h                       +37   -0
src/infiniop/ops/attention/operator.cc                       +291  -0
src/infiniop/devices/kunlun/kunlun_kernel_dtype.h  (new file, 0 → 100644)

#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__

#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"

// kunlun ptrdiff_t* is used to save ptrdiff_t array
// copied from host
typedef struct _ptrdiff_t {
    long value;   // 32 bit
    long padding; // 32 bit
} _ptrdiff_t;

// same as ptrdiff
typedef struct _size_t {
    size_t value;
    size_t padding;
} _size_t;

#endif
src/infiniop/devices/maca/common_maca.h

@@ -17,9 +17,24 @@ class Handle::Internal {
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

    int _warp_size, _max_threads_per_block, _block_size[3], _grid_size[3];

public:
    Internal(int);

    infiniStatus_t useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const;
    infiniStatus_t useMcdnn(hcStream_t stream, const Fn<hcdnnHandle_t> &f) const;

    int warpSize() const;
    int maxThreadsPerBlock() const;
    int blockSizeX() const;
    int blockSizeY() const;
    int blockSizeZ() const;
    int gridSizeX() const;
    int gridSizeY() const;
    int gridSizeZ() const;
};

hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);
...
src/infiniop/devices/maca/maca_handle.cc

@@ -3,7 +3,7 @@
namespace device::maca {

Handle::Handle(infiniDevice_t device, int device_id)
    : InfiniopHandle{device, device_id},
-     _internal(std::make_shared<Handle::Internal>()) {}
+     _internal(std::make_shared<Handle::Internal>(device_id)) {}

Handle::Handle(int device_id) : Handle(INFINI_DEVICE_METAX, device_id) {}
...

@@ -11,6 +11,19 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
    return _internal;
}

Handle::Internal::Internal(int device_id) {
    hcDeviceProp_t prop;
    hcGetDeviceProperties(&prop, device_id);
    _warp_size = prop.warpSize;
    _max_threads_per_block = prop.maxThreadsPerBlock;
    _block_size[0] = prop.maxThreadsDim[0];
    _block_size[1] = prop.maxThreadsDim[1];
    _block_size[2] = prop.maxThreadsDim[2];
    _grid_size[0] = prop.maxGridSize[0];
    _grid_size[1] = prop.maxGridSize[1];
    _grid_size[2] = prop.maxGridSize[2];
}

infiniStatus_t Handle::Internal::useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const {
    auto handle = mcblas_handles.pop();
    if (!handle) {
...

@@ -33,6 +46,15 @@ infiniStatus_t Handle::Internal::useMcdnn(hcStream_t stream, const Fn<hcdnnHandl
    return INFINI_STATUS_SUCCESS;
}

int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }

hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt) {
    switch (dt) {
    case INFINI_DTYPE_F16:
...
src/infiniop/devices/musa/common_musa.h

@@ -16,12 +16,27 @@ class Handle::Internal {
    Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
    Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;

    int _warp_size, _max_threads_per_block, _block_size[3], _grid_size[3];

    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    Internal(int);

    infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
    infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;

    int warpSize() const;
    int maxThreadsPerBlock() const;
    int blockSizeX() const;
    int blockSizeY() const;
    int blockSizeZ() const;
    int gridSizeX() const;
    int gridSizeY() const;
    int gridSizeZ() const;
};

} // namespace device::musa
src/infiniop/devices/musa/musa_handle.cc

@@ -3,7 +3,7 @@
namespace device::musa {

Handle::Handle(infiniDevice_t device, int device_id)
    : InfiniopHandle{device, device_id},
-     _internal(std::make_shared<Handle::Internal>()) {}
+     _internal(std::make_shared<Handle::Internal>(device_id)) {}

Handle::Handle(int device_id) : Handle(INFINI_DEVICE_MOORE, device_id) {}
...

@@ -11,6 +11,19 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
    return _internal;
}

Handle::Internal::Internal(int device_id) {
    musaDeviceProp prop;
    musaGetDeviceProperties(&prop, device_id);
    _warp_size = prop.warpSize;
    _max_threads_per_block = prop.maxThreadsPerBlock;
    _block_size[0] = prop.maxThreadsDim[0];
    _block_size[1] = prop.maxThreadsDim[1];
    _block_size[2] = prop.maxThreadsDim[2];
    _grid_size[0] = prop.maxGridSize[0];
    _grid_size[1] = prop.maxGridSize[1];
    _grid_size[2] = prop.maxGridSize[2];
}

infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
    std::unique_ptr<mublasHandle_t> handle;
    auto opt_handle = mublas_handles.pop();
...

@@ -40,6 +53,15 @@ infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::
    return INFINI_STATUS_SUCCESS;
}

int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }

infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(INFINI_DEVICE_MOORE, device_id);
    return INFINI_STATUS_SUCCESS;
...
src/infiniop/elementwise/cpu/elementwise_cpu.h  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_CPU_H__
#define __INFINIOP_ELEMENTWISE_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../elementwise.h"
#include <utility>
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
info_result.take(), \
nullptr, \
0, \
HANDLE->device, \
HANDLE->device_id);
namespace op::elementwise::cpu {
/**
* @brief CPU-specific device implementation for resource management and
* calculation implementations.
*
* This class encapsulates device-specific behavior and execution logic.
* Use the static create() method to instantiate a DeviceImpl.
*/
class DeviceImpl final {
    struct Opaque;
    std::shared_ptr<Opaque> _opaque;

    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    template <typename... Args>
    static utils::Result<DeviceImpl> create(Args &&...args);
/**
* @brief Dispatches an elementwise operation with uniform input types.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tdata The common data type of all inputs and output.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
    template <typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *stream,
                             Args &&...args);
/**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
    template <typename Op, typename Tout, typename... Tin, typename... Args,
              std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *stream,
                             Args &&...args);
};

// Define the Opaque struct for CPU, which is empty
struct DeviceImpl::Opaque {};

template <typename... Args>
utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
    return INFINI_STATUS_NOT_IMPLEMENTED;
}

// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
          std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
                    void *output,
                    const std::vector<const void *> &inputs,
                    std::index_sequence<Is...>,
                    Args &&...args) {
    Tout *out = reinterpret_cast<Tout *>(output);
    std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
    ptrdiff_t output_size = info.getOutputSize();

#pragma omp parallel for
    for (ptrdiff_t i = 0; i < output_size; ++i) {
        size_t out_idx = info.isOutputContiguous()
                           ? i
                           : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());

        auto get_input_idx = [&](size_t input_id) {
            return info.getInputContiguous()[input_id]
                     ? i
                     : (info.getInputBroadcasted()[input_id]
                            ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
                            : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
        };

        out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(
            std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
    }
}

// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args,
          std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
    calculate_impl<Op, Tout, Tin...>(info, output, inputs,
                                     std::make_index_sequence<sizeof...(Tin)>{},
                                     std::forward<Args>(args)...);
    return INFINI_STATUS_SUCCESS;
}

// Perform elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, size_t... Is, typename... Args>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
                    void *output,
                    const std::vector<const void *> &inputs,
                    std::index_sequence<Is...>,
                    Args &&...args) {
    Tdata *out = reinterpret_cast<Tdata *>(output);
    std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
    const ptrdiff_t output_size = info.getOutputSize();

#pragma omp parallel for
    for (ptrdiff_t i = 0; i < output_size; ++i) {
        size_t out_idx = info.isOutputContiguous()
                           ? i
                           : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());

        auto get_input_idx = [&](size_t input_id) {
            return info.getInputContiguous()[input_id]
                     ? i
                     : (info.getInputBroadcasted()[input_id]
                            ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
                            : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
        };

        if constexpr (std::is_same_v<Tdata, fp16_t>) {
            out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
        } else {
            out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
        }
    }
}

// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    constexpr size_t N = Op::num_inputs;
    calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::elementwise::cpu
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
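
To illustrate the dispatch pattern this header relies on — an operation functor with a compile-time num_inputs, unpacked over the input pointers with std::index_sequence — here is a minimal, self-contained sketch. It is not part of the diff; AddLikeOp, apply(), and the plain float vectors are hypothetical stand-ins for the Op/Tdata template parameters and the tensor buffers (contiguous case only, no broadcasting or OpenMP).

#include <array>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical functor in the same shape as AddOp: fixed arity plus operator().
struct AddLikeOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    T operator()(const T &a, const T &b) const { return a + b; }
};

// Unpack the input pointers with an index sequence, mirroring calculate_impl
// above (contiguous tensors, so the linear index is used directly).
template <typename Op, typename T, size_t... Is>
void apply(T *out, const std::array<const T *, sizeof...(Is)> &ins, size_t n,
           std::index_sequence<Is...>) {
    for (size_t i = 0; i < n; ++i) {
        out[i] = Op{}(ins[Is][i]...);
    }
}

int main() {
    std::vector<float> a{1, 2, 3}, b{10, 20, 30}, c(3);
    std::array<const float *, AddLikeOp::num_inputs> ins{a.data(), b.data()};
    apply<AddLikeOp, float>(c.data(), ins, c.size(),
                            std::make_index_sequence<AddLikeOp::num_inputs>{});
    std::cout << c[0] << " " << c[1] << " " << c[2] << "\n"; // 11 22 33
    return 0;
}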
src/infiniop/elementwise/cuda/elementwise_cuda.cuh  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__
#define __INFINIOP_ELEMENTWISE_CUDA_H__
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
/**
* @brief Casts an untyped device pointer to a typed pointer of type T.
*
* @tparam T Desired pointer type.
*
* @param ptr Untyped pointer.
* @return Pointer of type const T*.
*/
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
    return reinterpret_cast<const T *>(ptr);
}
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
                                                 const size_t *shape, const ptrdiff_t *strides) {
    return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides);
}
/**
* @brief Computes input element offset for broadcasting and strided access.
*
* Used to map a linear output index to the corresponding index in an input tensor,
* considering contiguity and broadcasting.
*/
struct InputIndexer {
    size_t idx;
    size_t ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const size_t *input_shapes;
    const ptrdiff_t *input_strides;
    const ptrdiff_t *output_strides;
/**
* @brief Computes the memory offset for a given input tensor at current index.
*
* @param input_id ID of the input tensor.
* @return Offset into the input tensor.
*/
    __device__ __forceinline__ size_t operator()(size_t input_id) const {
        return input_contiguous[input_id]
                 ? idx
                 : (input_broadcasted[input_id]
                        ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
                        : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
    }
};
/**
* @brief Invokes a callable with compile-time index constants.
*
* Used to unpack index sequence for variadic template processing of inputs.
*
* @tparam F Callable type.
* @tparam Is Compile-time index sequence.
*
* @param f Callable to invoke with index constants.
*/
template <typename F, size_t... Is>
__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<Is...>) {
    f(std::integral_constant<size_t, Is>{}...);
}
/**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
*
* @tparam N Number of input tensors.
* @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output.
* @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory.
* @param input_contiguous Array indicating if each input tensor is contiguous.
* @param input_broadcasted Array indicating if each input tensor is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor.
* @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *__restrict__ input_contiguous,
    const bool *__restrict__ input_broadcasted,
    const size_t *__restrict__ output_shape,
    const size_t *__restrict__ input_shapes,
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    Tdata *output,
    const void *const *inputs,
    size_t offset,
    Args... args) {

    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
    if (idx < output_size) {
        const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};

        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
            },
            std::make_index_sequence<N>{});
    }
}
/**
* @brief CUDA kernel for performing an elementwise operation on tensors with support
* for broadcasting and mixed data types.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors.
* @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution.
*/
template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *__restrict__ input_contiguous,
    const bool *__restrict__ input_broadcasted,
    const size_t *__restrict__ output_shape,
    const size_t *__restrict__ input_shapes,
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    Tout *output,
    const void *const *__restrict__ inputs,
    size_t offset) {

    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
    if (idx < output_size) {
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};

        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}.template operator()<Tout, Tin...>(
                    (typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
            },
            std::index_sequence_for<Tin...>{});
    }
}
struct DeviceImpl::Opaque {
    std::shared_ptr<device::cuda::Handle::Internal> internal;

    Opaque(const std::shared_ptr<device::cuda::Handle::Internal> &internal)
        : internal(internal) {}
/**
* @brief Executes an elementwise operation where all inputs and the output share the same data type.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tdata Data type of both input and output tensors.
* @tparam Args Optional additional arguments passed to the operation.
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
    template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 cudaStream_t stream,
                                 Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tdata *>(output), inputs,
            elementwiseKernel<N, Op, Tdata, Args...>,
            stream,
            std::forward<Args>(args)...);
    }
/**
* @brief Executes an elementwise operation with mixed input and output data types.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tout Data type of the output tensor.
* @tparam Tin... Data types of the input tensors.
* @tparam Args Optional additional arguments passed to the operation.(UNUSED)
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
    template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
              std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 cudaStream_t stream,
                                 Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tout *>(output), inputs,
            elementwiseKernel<Op, Tout, Tin...>,
            stream);
    }
private:
/**
* @brief Transfers elementwise operation metadata and input pointers from host to device memory.
*
* @tparam N Number of input tensors.
*
* @param info Elementwise operation metadata (shapes, strides, flags, etc.).
* @param workspace Pointer to device workspace memory for storing metadata and input pointers.
* @param h_inputs_arr Host array of input tensor pointers.
* @param d_inputs_arr Output reference to device array of input tensor pointers.
* @param d_input_contiguous Output reference to device array indicating whether each input is contiguous.
* @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted.
* @param d_output_shape Output reference to device array holding the output tensor shape.
* @param d_output_strides Output reference to device array holding output tensor strides.
* @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
* @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
* @param stream CUDA stream used for asynchronous memory transfer.
* @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
*/
    template <size_t N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        const void **&d_inputs_arr,
        const bool *&d_input_contiguous,
        const bool *&d_input_broadcasted,
        const size_t *&d_output_shape,
        const ptrdiff_t *&d_output_strides,
        const size_t *&d_input_shapes,
        const ptrdiff_t *&d_input_strides,
        cudaStream_t stream) const {

        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        const int8_t *info_meta_start = info.getMetaStart();
        const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;

        // copy the input pointer array and meta to device
        CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream));
        CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream));

        // offset/assign the pointers
        d_inputs_arr = reinterpret_cast<const void **>(workspace);
        d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);

        return INFINI_STATUS_SUCCESS;
    }
/**
* @brief Launches the elementwise kernel for the specified operation.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam N Number of input tensors.
* @tparam KernelFunc Type of the kernel function pointer.
* @tparam Tout Output data type.
* @tparam Args Additional arguments to be forwarded to the kernel.
*
* @param info Metadata about the elementwise operation (shapes, strides, etc.).
* @param workspace CUDA memory used for storing metadata.
* @param output Pointer to output buffer on device.
* @param inputs Vector of device pointers to input tensors.
* @param kernel_func Kernel function to launch.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments passed to the kernel.
* @return infiniStatus_t Status code indicating success or failure.
*/
    template <uint32_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
    infiniStatus_t launchElementwiseKernel(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        Tout *output,
        const std::vector<const void *> &inputs,
        KernelFunc kernel_func,
        cudaStream_t stream,
        Args &&...args) {

        auto output_size = info.getOutputSize();
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }

        // Device pointers
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;

        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides, stream));

        dim3 blockDims(std::min(BLOCK_SIZE, static_cast<uint32_t>(internal->maxThreadsPerBlock())));
        dim3 gridDims(std::min(uint32_t(CEIL_DIV(output_size, blockDims.x)), static_cast<uint32_t>(internal->gridSizeX())));
        size_t step = gridDims.x * blockDims.x;

        for (size_t i = 0; i < output_size; i += step) {
            kernel_func<<<gridDims, blockDims, 0, stream>>>(
                output_size, info.getNdim(), info.isOutputContiguous(),
                d_input_contiguous, d_input_broadcasted,
                d_output_shape, d_input_shapes,
                d_output_strides, d_input_strides,
                output, reinterpret_cast<const void **>(d_inputs_arr), i,
                std::forward<Args>(args)...);
        }
        return INFINI_STATUS_SUCCESS;
    }
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
    return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}

/* Invoke elementwise operation for different input types */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
          std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    constexpr size_t N = Op::num_inputs;
    static_assert(sizeof...(Tin) == N, "Input type count mismatch");
    return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
        info, workspace, output, inputs,
        reinterpret_cast<cudaStream_t>(stream),
        std::forward<Args>(args)...);
}

/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    constexpr size_t N = Op::num_inputs;
    return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
        info, workspace, output, inputs,
        reinterpret_cast<cudaStream_t>(stream),
        std::forward<Args>(args)...);
}

} // namespace op::elementwise::cuda

#endif // __INFINIOP_ELEMENTWISE_CUDA_H__
src/infiniop/elementwise/cuda/elementwise_cuda_api.cuh  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__
#define __INFINIOP_ELEMENTWISE_CUDA_API_H__
#include "../elementwise.h"
namespace op::elementwise::cuda {
/**
* @brief Define the methods and info needed by CUDA to perform elementwise operation
*/
class DeviceImpl final {
    struct Opaque;
    std::shared_ptr<Opaque> _opaque;

    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
    template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *workspace,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *stream,
                             Args &&...args);
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
    template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
              std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *workspace,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *stream,
                             Args &&...args);
};

} // namespace op::elementwise::cuda
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
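
For reference, the workspace_size computed by this macro is the packed metadata block defined in elementwise.h plus one device pointer per input. A minimal sketch of that arithmetic for a hypothetical case (2 inputs, rank 3, LP64 sizes) — the concrete numbers below are illustrative assumptions, not values taken from the diff:

#include <cstddef>
#include <iostream>

int main() {
    constexpr size_t N = 2;    // hypothetical number of inputs
    constexpr size_t ndim = 3; // hypothetical rank

    // Mirrors ElementwiseInfo::create(): output shape + strides, per-input
    // shapes + strides, then two bool flags per input.
    constexpr size_t meta_bytes = ndim * (sizeof(size_t) + sizeof(ptrdiff_t))
                                + N * ndim * sizeof(size_t)
                                + N * ndim * sizeof(ptrdiff_t)
                                + 2 * N * sizeof(bool); // 148 on LP64

    // The meta vector is rounded up to whole size_t elements (CEIL_DIV),
    // so getMetaMemSize() reports the rounded size.
    constexpr size_t meta_mem_size =
        ((meta_bytes + sizeof(size_t) - 1) / sizeof(size_t)) * sizeof(size_t); // 152

    // Workspace = rounded metadata + the input-pointer array copied to device.
    constexpr size_t workspace = meta_mem_size + N * sizeof(void *); // 168

    std::cout << workspace << " bytes\n";
    return 0;
}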
src/infiniop/elementwise/elementwise.h  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__
#include "../../utils.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <array>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(std::move(device_info)), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
}; \
}
namespace op::elementwise {
/**
* @brief Stores the metadata required for performing an elementwise operation.
*
* This struct encapsulates shape, stride, and layout information for both
* output and multiple input tensors involved in an elementwise operation.
*
* Memory is manually managed and freed in the destructor.
* Supports move construction but disallows copy construction and copy/move assignment.
*
* Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
*/
struct ElementwiseInfo {
private:
    std::vector<size_t> _meta;
    size_t _output_size;
    size_t _input_size;
    size_t _ndim;
    bool _output_contiguous;

    ElementwiseInfo(std::vector<size_t> meta,
                    size_t output_size,
                    size_t input_size,
                    size_t ndim,
                    bool output_contiguous)
        : _meta(std::move(meta)),
          _output_size(output_size),
          _input_size(input_size),
          _ndim(ndim),
          _output_contiguous(output_contiguous) {}

public:
    // Get the Memory size of the meta data in bytes
    inline size_t getMetaMemSize() const {
        return _meta.size() * sizeof(size_t);
    }
    inline const int8_t *getMetaStart() const {
        return reinterpret_cast<const int8_t *>(_meta.data());
    }
    inline size_t getOutputSize() const {
        return _output_size;
    }
    inline size_t getInputSize() const {
        return _input_size;
    }
    inline size_t getNdim() const {
        return _ndim;
    }
    inline bool isOutputContiguous() const {
        return _output_contiguous;
    }
    inline const size_t *getOutputShape() const {
        return reinterpret_cast<const size_t *>(_meta.data());
    }
    inline const ptrdiff_t *getOutputStrides() const {
        return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
    }
    inline const size_t *getAllInputShapes() const {
        return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
    }
    inline const size_t *getInputShape(const size_t &index) const {
        if (index < _input_size) {
            return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
        }
        return nullptr;
    }
    inline const ptrdiff_t *getAllInputStrides() const {
        return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
    }
    inline const ptrdiff_t *getInputStrides(const size_t &index) const {
        if (index < _input_size) {
            return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
        }
        return nullptr;
    }
    inline const bool *getInputContiguous() const {
        return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
    }
    inline const bool *getInputBroadcasted() const {
        return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
    }

    using ResultType = utils::Result<ElementwiseInfo>;
/**
* @brief Construct ElementwiseInfo from output and input tensor descriptors.
* @param output_desc Descriptor of the output tensor.
* @param input_descs Descriptors of the input tensors.
* @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
* or the status code.
*/
    static ResultType create(
        infiniopTensorDescriptor_t output_desc,
        std::vector<infiniopTensorDescriptor_t> input_descs) {

        if (!output_desc || input_descs.empty()) {
            return INFINI_STATUS_BAD_PARAM;
        }

        // Destination cannot have broadcast setup
        if (output_desc->hasBroadcastDim()) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        auto input_size = input_descs.size();
        auto ndim = output_desc->ndim();
        auto output_size = output_desc->numel();
        auto output_contiguous = output_desc->isContiguous();

        // Allocate memory for meta
        auto shape_unit = output_desc->dim(0);
        auto stride_unit = output_desc->stride(0);
        size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
                             + input_size * ndim * sizeof(shape_unit)
                             + input_size * ndim * sizeof(stride_unit)
                             + 2 * input_size * sizeof(bool);
        std::vector<size_t> meta(CEIL_DIV(meta_mem_size, sizeof(size_t)));
        int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());

        const auto output_shape = output_desc->shape();
        const auto output_strides = output_desc->strides();

        // Pointers to the sections within _meta
        size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
        ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
        size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
        ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
        bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
        bool *input_broadcasted = input_contiguous + input_size;

        // Copy output shape and strides
        std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
        std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));

        // Copy input shapes, strides, contiguous, and broadcasted flags
        for (size_t i = 0; i < input_size; ++i) {
            auto &desc = input_descs[i];
            const auto in_shape = desc->shape();
            const auto in_strides = desc->strides();
            std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
            std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
            input_contiguous[i] = desc->isContiguous();
            input_broadcasted[i] = !input_contiguous[i]
                                && (desc->ndim() != ndim || desc->hasBroadcastDim());
        }

        ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
        return ResultType(std::move(info));
    }
};

} // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__
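
For readers following the pointer arithmetic in the device-side infoToDevice() helpers, the packed _meta buffer built by create() above lays out its sections in this fixed order (a summary of the code above, not an addition to the diff):

    output_shape       ndim      x sizeof(size_t)
    output_strides     ndim      x sizeof(ptrdiff_t)
    input_shapes       N * ndim  x sizeof(size_t)     (grouped per input)
    input_strides      N * ndim  x sizeof(ptrdiff_t)  (grouped per input)
    input_contiguous   N         x sizeof(bool)
    input_broadcasted  N         x sizeof(bool)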
src/infiniop/elementwise/kunlun/elementwise_kunlun.h  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {

struct DeviceImpl::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;

    Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
        : internal(internal_) {}
    template <size_t N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 kunlunStream_t stream,
                                 Args &&...args) {
        auto output_size = info.getOutputSize();
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }

        // Device pointers
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;

        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides));

        Op::template launch<Tdata>(
            output_size, info.getNdim(), info.isOutputContiguous(),
            reinterpret_cast<const void *>(d_input_contiguous),
            reinterpret_cast<const void *>(d_input_broadcasted),
            reinterpret_cast<const void *>(d_output_shape),
            reinterpret_cast<const void *>(d_input_shapes),
            reinterpret_cast<const void *>(d_output_strides),
            reinterpret_cast<const void *>(d_input_strides),
            output,
            reinterpret_cast<const void *const *>(d_inputs_arr),
            stream, args...);

        return INFINI_STATUS_SUCCESS;
    }

private:
    template <size_t N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        const void **&d_inputs_arr,
        const bool *&d_input_contiguous,
        const bool *&d_input_broadcasted,
        const size_t *&d_output_shape,
        const ptrdiff_t *&d_output_strides,
        const size_t *&d_input_shapes,
        const ptrdiff_t *&d_input_strides) const {

        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        const int8_t *info_meta_start = info.getMetaStart();
        const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;

        // copy the input pointer array and meta to device
        CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
        CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));

        // offset/assign the pointers
        d_inputs_arr = reinterpret_cast<const void **>(workspace);
        d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);

        return INFINI_STATUS_SUCCESS;
    }
};

template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
    return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}

template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    constexpr size_t N = Op::num_inputs;
    return _opaque->calculateImpl<N, Op, Tdata>(info, workspace, output, inputs,
                                                reinterpret_cast<kunlunStream_t>(stream),
                                                std::forward<Args>(args)...);
}

} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
src/infiniop/elementwise/kunlun/elementwise_kunlun_api.h  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {

class DeviceImpl final {
    struct Opaque;
    std::shared_ptr<Opaque> _opaque;

    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);

    template <typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *workspace,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *stream,
                             Args &&...args);
};

} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
src/infiniop/elementwise/kunlun/elementwise_kunlun_kernel.h  (new file, 0 → 100644)

#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
using namespace device::kunlun::kernel;
/**
* @brief Computes input tile offset
*/
struct InputIndexer {
    size_t idx;
    size_t ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const _size_t *input_shapes;
    const _ptrdiff_t *input_strides;
    const _ptrdiff_t *output_strides;

    __device__ size_t operator()(size_t input_id) const {
        return input_contiguous[input_id]
                 ? idx
                 : (input_broadcasted[input_id]
                        ? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
                        : indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
    }
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
                                        const _size_t *shape, const _ptrdiff_t *strides) {
    return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
    __global_ptr__ Tdata **typed_inputs, // gm pointer
    __global_ptr__ Tdata *output,        // gm pointer output
    Tdata *inputs_buf,                   // local mem buffer
    size_t *input_indexes,
    size_t output_index,
    Args... args) {

    static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");

#pragma unroll
    // Copy inputs to buf
    for (size_t i = 0; i < N; i++) {
        auto gm = typed_inputs[i] + input_indexes[i];
        auto lm = inputs_buf + i;
        GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
    }
    mfence();

    // Calculate elementwise
    // Inputs save all operands
    Tdata out = Op{}(inputs_buf, args...);

    // Copy out to gm
    LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
    mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *input_contiguous_gm,
    const bool *input_broadcasted_gm,
    const _size_t *output_shape_gm,
    const _size_t *input_shapes_gm,
    const _ptrdiff_t *output_strides_gm,
    const _ptrdiff_t *input_strides_gm,
    Tdata *output,
    const void *const *inputs,
    Args... args) {

    int cid = core_id();
    int ncores = core_num();
    if (cid >= ncores) {
        return;
    }
    int thread_id = ncores * cluster_id() + cid;
    int nthreads = ncores * cluster_num();

    // Cast input gm pointer type
    auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);

    const int BUFF_SIZE = 64;
    // Input data cache
    __local__ Tdata inputs_buf[N];
    // Input contiguous/broadcasted flags
    __local__ bool input_contiguous[N];
    __local__ bool input_broadcasted[N];
    // Input shape/strides
    __local__ _size_t input_shapes[N * ndim];
    __local__ _ptrdiff_t input_strides[N * ndim];
    // Output shape/strides
    __local__ _size_t output_shape[ndim];
    __local__ _ptrdiff_t output_strides[ndim];
    // Inputs gm ptr buf
    __local__ __global_ptr__ Tdata *typed_inputs_ptr[N];

    // Load from gm
    GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
    GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
    GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
    GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
    GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
    mfence();

    int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));

    for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
        size_t read_len = min(len_per_loop, output_size - start);
        for (int idx = start; idx < start + read_len; ++idx) {
            size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous, ndim, output_shape, output_strides);
            InputIndexer indexer{static_cast<size_t>(idx), ndim,
                                 input_contiguous, input_broadcasted,
                                 input_shapes, input_strides, output_strides};
            // Get index offset for every operand
            size_t indexes[N];
            for (size_t i = 0; i < N; i++) {
                indexes[i] = indexer(i);
            }
            // Launch operator
            launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
        }
    }
    sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const _ptrdiff_t *>(output_strides), \
reinterpret_cast<const _ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
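The LAUNCH_ELEMENTWISE_KERNEL_IMPL / LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE pair above only generates and explicitly instantiates a typed launcher; a concrete operator still has to supply the functor and choose its data types. A minimal sketch of that wiring, assuming a hypothetical Kunlun AddOp functor (the real per-op functor lives elsewhere and would carry the appropriate XPU function qualifiers):

// Hypothetical functor for illustration only.
struct AddOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    T operator()(const T &a, const T &b) const { return a + b; }
};

LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)        // defines launchAddKernel<Tdata, Args...>(...)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float) // emits launchAddKernel<float>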
src/infiniop/ops/add/cpu/add_cpu.cc
0 → 100644
#include "add_cpu.h"
namespace op::add::cpu {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<AddOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<AddOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<AddOp, double>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add::cpu
src/infiniop/ops/add/cpu/add_cpu.h
0 → 100644
#ifndef __ADD_CPU_H__
#define __ADD_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR(add, cpu)

namespace op::add::cpu {
typedef struct AddOp {
public:
    static constexpr size_t num_inputs = 2;

    template <typename T>
    T operator()(const T &a, const T &b) const {
        return a + b;
    }
} AddOp;
} // namespace op::add::cpu
#endif // __ADD_CPU_H__
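AddOp only fixes the arity (num_inputs = 2) and the scalar operation; the element loop, broadcasting, and stride handling come from the elementwise framework pulled in via CREATE_ELEMENTWISE_CPU_DESCRIPTOR. Ignoring broadcasting and strides, the inner loop behaves roughly like the sketch below (the helper name is illustrative, not part of the API):

#include <cstddef>

// Illustrative only: contiguous, non-broadcast application of a binary
// elementwise functor such as op::add::cpu::AddOp.
template <typename Op, typename T>
void apply_elementwise(const T *a, const T *b, T *out, std::size_t n) {
    Op op;
    for (std::size_t i = 0; i < n; ++i) {
        out[i] = op(a[i], b[i]); // AddOp: out[i] = a[i] + b[i]
    }
}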
src/infiniop/ops/add/cuda/add_cuda.cu
0 → 100644
#include "add_cuda.cuh"
#include "add_cuda_internal.cuh"
namespace op::add::cuda {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CUDA elementwise descriptor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, AddOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, AddOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add::cuda
src/infiniop/ops/add/cuda/add_cuda.cuh
0 → 100644
#ifndef __ADD_CUDA_API_H__
#define __ADD_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(add, cuda)
#endif // __ADD_CUDA_API_H__
src/infiniop/ops/add/cuda/add_cuda_internal.cuh
0 → 100644
#ifndef __ADD_CUDA_H__
#define __ADD_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::add::cuda {
typedef struct AddOp {
public:
    static constexpr size_t num_inputs = 2;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        if constexpr (std::is_same_v<T, half2>) {
            return __hadd2(a, b);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hadd(a, b);
        } else if constexpr (std::is_same_v<T, float>) {
            return __fadd_rd(a, b);
        } else {
            return a + b;
        }
    }
} AddOp;
} // namespace op::add::cuda
#endif // __ADD_CUDA_H__
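Two notes on the branches above: the half2 case lets the elementwise dispatcher add two FP16 lanes per instruction when it packs values into half2, and __fadd_rd is a round-toward-negative-infinity add, so FP32 results can differ from a plain round-to-nearest a + b by one ULP.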
src/infiniop/ops/add/operator.cc
0 → 100644
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/add_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateAddDescriptor(
    infiniopHandle_t handle,
    infiniopAddDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {

#define CREATE(CASE, NAMESPACE)                                             \
    case CASE:                                                              \
        return op::add::NAMESPACE::Descriptor::create(                      \
            handle,                                                         \
            reinterpret_cast<op::add::NAMESPACE::Descriptor **>(desc_ptr),  \
            c_desc,                                                         \
            {a_desc,                                                        \
             b_desc})

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {

#define GET(CASE, NAMESPACE)                                                                \
    case CASE:                                                                              \
        *size = reinterpret_cast<op::add::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
        GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef GET

    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopAdd(
    infiniopAddDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *c,
    const void *a,
    const void *b,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                              \
    case CASE:                                                                  \
        return reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc)   \
            ->calculate(workspace, workspace_size, c, {a, b}, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {

#define DELETE(CASE, NAMESPACE)                                                  \
    case CASE:                                                                   \
        delete reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc);   \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
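Taken together, the four entry points above follow the usual descriptor lifecycle: create, query workspace, execute, destroy. A minimal host-side sketch, assuming the handle, tensor descriptors, device buffers, workspace, and stream were created elsewhere (additional infiniop headers may be needed for the handle and tensor descriptor types):

#include "infiniop/ops/add.h"

// Sketch: c = a + b for three identically shaped tensors. Returns the first
// non-success status encountered.
infiniStatus_t run_add(infiniopHandle_t handle,
                       infiniopTensorDescriptor_t c_desc,
                       infiniopTensorDescriptor_t a_desc,
                       infiniopTensorDescriptor_t b_desc,
                       void *c, const void *a, const void *b,
                       void *workspace, size_t workspace_capacity,
                       void *stream) {
    infiniopAddDescriptor_t add_desc;
    infiniStatus_t status = infiniopCreateAddDescriptor(handle, &add_desc, c_desc, a_desc, b_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    size_t needed = 0;
    status = infiniopGetAddWorkspaceSize(add_desc, &needed);
    if (status == INFINI_STATUS_SUCCESS && needed > workspace_capacity) {
        status = INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    if (status == INFINI_STATUS_SUCCESS) {
        status = infiniopAdd(add_desc, workspace, needed, c, a, b, stream);
    }

    infiniopDestroyAddDescriptor(add_desc);
    return status;
}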
src/infiniop/ops/attention/attention.h
0 → 100644
#ifndef ATTENTION_H
#define ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
}; \
}
#endif // ATTENTION_H
src/infiniop/ops/attention/operator.cc
0 → 100644
#include "../../operator.h"
#include "../../../utils.h"
#include "../../../utils/check.h"
#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/rearrange.h"
#include <cmath>
#include <cstdint>
struct InfiniopAttentionDescriptor {
    InfiniopDescriptor _super;
    infiniopRearrangeDescriptor_t rearrange_desc_k;
    infiniopRearrangeDescriptor_t rearrange_desc_v;
    infiniopRearrangeDescriptor_t rearrange_desc_q;
    infiniopRearrangeDescriptor_t rearrange_desc_out;
    infiniopGemmDescriptor_t matmul_desc1;
    infiniopGemmDescriptor_t matmul_desc2;
    infiniopCausalSoftmaxDescriptor_t softmax_desc;
    size_t workspace_size;
    size_t op_workspace_offset;
    size_t op_workspace_size;
    size_t q_cont_offset;
    size_t att_score_offset;
    size_t att_val_offset;
    size_t k_cache_offset;
    size_t v_cache_offset;
    float qk_alpha;
};

__C __export infiniStatus_t infiniopCreateAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    size_t pos) {

    if (out_desc->ndim() != 3
        || q_desc->ndim() != 3
        || k_desc->ndim() != 3
        || v_desc->ndim() != 3
        || k_cache_desc->ndim() != 3
        || v_cache_desc->ndim() != 3) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    if (!out_desc->isContiguous(0, 2)) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }
    if (q_desc->strides()[2] != 1
        || k_desc->strides()[2] != 1
        || v_desc->strides()[2] != 1
        || k_cache_desc->strides()[2] != 1
        || v_cache_desc->strides()[2] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    size_t n_q_head = q_desc->shape()[0];
    size_t seq_len = q_desc->shape()[1];
    size_t head_dim = q_desc->shape()[2];
    size_t hidden_size = n_q_head * head_dim;
    size_t n_kv_head = k_desc->shape()[0];
    size_t total_seq_len = seq_len + pos;
    size_t n_group = n_q_head / n_kv_head;
    size_t alignment = 256;

    if (out_desc->shape()[0] != seq_len
        || out_desc->shape()[1] != n_q_head
        || out_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // k: [n_kv_head, seq_len, head_dim]
    if (k_desc->shape()[0] != n_kv_head
        || k_desc->shape()[1] != seq_len
        || k_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // v: [n_kv_head, seq_len, head_dim]
    if (v_desc->shape()[0] != n_kv_head
        || v_desc->shape()[1] != seq_len
        || v_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // k_cache: [n_kv_head, _, head_dim]
    if (k_cache_desc->shape()[0] != n_kv_head
        || k_cache_desc->shape()[1] < total_seq_len
        || k_cache_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }
    // v_cache: [n_kv_head, _, head_dim]
    if (v_cache_desc->shape()[0] != n_kv_head
        || v_cache_desc->shape()[1] < total_seq_len
        || v_cache_desc->shape()[2] != head_dim) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Rearrange k into k_cache
    infiniopTensorDescriptor_t dst_k_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &dst_k_desc, 3, k_desc->shape().data(),
        k_cache_desc->strides().data(), k_cache_desc->dtype()));
    infiniopRearrangeDescriptor_t rearrange_desc_k;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(
        handle, &rearrange_desc_k, dst_k_desc, k_desc));

    // Rearrange v into v_cache
    infiniopTensorDescriptor_t dst_v_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &dst_v_desc, 3, v_desc->shape().data(),
        v_cache_desc->strides().data(), v_cache_desc->dtype()));
    infiniopRearrangeDescriptor_t rearrange_desc_v;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(
        handle, &rearrange_desc_v, dst_v_desc, v_desc));

    infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
    size_t q_cont_size = 0;
    infiniopTensorDescriptor_t rearranged_q_desc;
    // Rearrange q into contiguous
    if (!q_desc->isContiguous(0, 1)) {
        CHECK_STATUS(infiniopCreateTensorDescriptor(
            &rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
        q_cont_size = utils::align(
            rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()), alignment);
        rearrange_desc_q = new InfiniopDescriptor;
        CHECK_STATUS(infiniopCreateRearrangeDescriptor(
            handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
    }

    // Matmul1: q * full_k
    // q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
    infiniopTensorDescriptor_t reshaped_q_desc;
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));

    // full_k: [n_kv_head, head_dim, total_seq_len]
    infiniopTensorDescriptor_t full_k_desc;
    size_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &full_k_desc, 3, full_k_shape,
        k_cache_desc->strides().data(), k_cache_desc->dtype()));
    TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));

    // qk: [n_kv_head, n_group * seq_len, total_seq_len]
    infiniopTensorDescriptor_t qk_desc;
    size_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));

    // matmul1_desc
    // qk_alpha
    float qk_alpha = 1 / sqrt(head_dim);
    infiniopGemmDescriptor_t matmul1_desc;
    CHECK_STATUS(infiniopCreateGemmDescriptor(
        handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
    // matmul1 workspace size
    size_t matmul1_workspace_size;
    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
    matmul1_workspace_size = utils::align(matmul1_workspace_size, alignment);
    // attention score tensor size
    size_t attn_score_size = utils::align(
        qk_desc->numel() * infiniSizeOf(qk_desc->dtype()), alignment);

    // CausalSoftmax: softmax(qk)
    // qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
    infiniopCausalSoftmaxDescriptor_t softmax_desc;
    CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(
        handle, &softmax_desc, qk_desc, qk_desc));
    // softmax workspace size
    size_t softmax_workspace_size;
    CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
    softmax_workspace_size = utils::align(softmax_workspace_size, alignment);

    // Matmul2: softmax(qk) * full_v
    // softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
    // full_v: [n_kv_head, total_seq_len, head_dim]
    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
    infiniopTensorDescriptor_t full_v_desc;
    size_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &full_v_desc, 3, full_v_shape,
        v_cache_desc->strides().data(), v_cache_desc->dtype()));
    // temp_out: [n_kv_head, n_group * seq_len, head_dim]
    infiniopTensorDescriptor_t att_val_desc;
    size_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
    CHECK_STATUS(infiniopCreateTensorDescriptor(
        &att_val_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
    // matmul2_desc
    infiniopGemmDescriptor_t matmul2_desc;
    CHECK_STATUS(infiniopCreateGemmDescriptor(
        handle, &matmul2_desc, att_val_desc, qk_desc, full_v_desc));
    // matmul2 workspace size
    size_t matmul2_workspace_size;
    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
    matmul2_workspace_size = utils::align(matmul2_workspace_size, alignment);
    // attention value tensor size
    size_t att_val_size = utils::align(
        att_val_desc->numel() * infiniSizeOf(att_val_desc->dtype()), alignment);

    // Rearrange temp_out into out
    // out: [seq_len, n_q_head, head_dim]
    // temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
    TRANSFORM_TENSOR_DESC(att_val_desc, dimSplit(1, {n_group, seq_len}));
    TRANSFORM_TENSOR_DESC(att_val_desc, dimMerge(0, 1));
    TRANSFORM_TENSOR_DESC(att_val_desc, dimPermute({1, 0, 2}));
    infiniopRearrangeDescriptor_t rearrange_desc_out;
    CHECK_STATUS(infiniopCreateRearrangeDescriptor(
        handle, &rearrange_desc_out, out_desc, att_val_desc));

    // workspace size
    size_t op_workspace_size = utils::align(
        std::max(std::max(matmul1_workspace_size, matmul2_workspace_size),
                 softmax_workspace_size),
        alignment);
    size_t temp_tensors_size = attn_score_size + std::max(q_cont_size, att_val_size);
    size_t workspace_size = temp_tensors_size + op_workspace_size;

    // k_cache_offset
    size_t k_cache_offset = 0;
    if (pos > 0) {
        k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
    }
    // v_cache_offset
    size_t v_cache_offset = 0;
    if (pos > 0) {
        v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
    }

    // create attention descriptor
    *(InfiniopAttentionDescriptor **)desc_ptr = new InfiniopAttentionDescriptor{
        {handle->device, handle->device_id},
        rearrange_desc_k,
        rearrange_desc_v,
        rearrange_desc_q,
        rearrange_desc_out,
        matmul1_desc,
        matmul2_desc,
        softmax_desc,
        workspace_size,
        temp_tensors_size,
        op_workspace_size,
        attn_score_size,
        0,
        attn_score_size,
        k_cache_offset,
        v_cache_offset,
        1.f / std::sqrt(float(head_dim)),
    };
    return INFINI_STATUS_SUCCESS;
}

__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(
    infiniopAttentionDescriptor_t desc, size_t *size) {
    *size = ((InfiniopAttentionDescriptor *)desc)->workspace_size;
    return INFINI_STATUS_SUCCESS;
}

__C __export infiniStatus_t infiniopAttention(
    infiniopAttentionDescriptor_t desc_,
    void *workspace_,
    size_t workspace_size_,
    void *out,
    void const *q,
    void const *k,
    void const *v,
    void *k_cache,
    void *v_cache,
    void *stream) {

    auto desc = (InfiniopAttentionDescriptor *)desc_;
    if (workspace_size_ < desc->workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE; // STATUS_MEMORY_NOT_ALLOCATED
    }
    void *workspace = (char *)workspace_ + desc->op_workspace_offset;
    size_t workspace_size = desc->op_workspace_size;
    void *att_score = (char *)workspace_ + desc->att_score_offset;
    void *att_val = (char *)workspace_ + desc->att_val_offset;
    void const *q_ = q;

    // concat k and v to k_cache and v_cache
    CHECK_STATUS(infiniopRearrange(
        desc->rearrange_desc_k,
        (char *)k_cache + desc->k_cache_offset, k, stream));
    CHECK_STATUS(infiniopRearrange(
        desc->rearrange_desc_v,
        (char *)v_cache + desc->v_cache_offset, v, stream));

    // rearrange q into contiguous
    if (desc->rearrange_desc_q) {
        void *q_cont = (char *)workspace_ + desc->q_cont_offset;
        CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, q_cont, q, stream));
        q_ = q_cont;
    }

    // matmul1: q * full_k
    CHECK_STATUS(infiniopGemm(
        desc->matmul_desc1, workspace, workspace_size,
        att_score, q_, k_cache, desc->qk_alpha, 0.0, stream));
    // softmax(qk)
    CHECK_STATUS(infiniopCausalSoftmax(
        desc->softmax_desc, workspace, workspace_size,
        att_score, att_score, stream));
    // matmul2: softmax(qk) * full_v
    CHECK_STATUS(infiniopGemm(
        desc->matmul_desc2, workspace, workspace_size,
        att_val, att_score, v_cache, 1.0, 0.0, stream));
    // rearrange out
    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, att_val, stream));

    return INFINI_STATUS_SUCCESS;
}

__C __export infiniStatus_t infiniopDestroyAttentionDescriptor(
    infiniopAttentionDescriptor_t desc_) {
    auto desc = (InfiniopAttentionDescriptor *)desc_;
    if (desc->rearrange_desc_q) {
        CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
    }
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
    CHECK_STATUS(infiniopDestroyCausalSoftmaxDescriptor(desc->softmax_desc));
    delete desc;
    return INFINI_STATUS_SUCCESS;
}
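End-to-end, infiniopAttention appends k and v to their caches, runs q·kᵀ scaled by qk_alpha = 1/sqrt(head_dim), applies the causal softmax in place, multiplies by the cached values, and rearranges the result into out. A minimal sketch of one step, assuming the descriptors, device buffers, workspace, and stream are owned by the caller (additional infiniop headers may be needed for the handle and tensor descriptor types):

#include "infiniop/ops/attention.h"

// Sketch: one attention step over `pos` cached tokens plus the new q/k/v.
infiniStatus_t run_attention(infiniopHandle_t handle,
                             infiniopTensorDescriptor_t out_desc,
                             infiniopTensorDescriptor_t q_desc,
                             infiniopTensorDescriptor_t k_desc,
                             infiniopTensorDescriptor_t v_desc,
                             infiniopTensorDescriptor_t k_cache_desc,
                             infiniopTensorDescriptor_t v_cache_desc,
                             size_t pos,
                             void *out, const void *q, const void *k, const void *v,
                             void *k_cache, void *v_cache,
                             void *workspace, size_t workspace_capacity,
                             void *stream) {
    infiniopAttentionDescriptor_t desc;
    infiniStatus_t status = infiniopCreateAttentionDescriptor(
        handle, &desc, out_desc, q_desc, k_desc, v_desc,
        k_cache_desc, v_cache_desc, pos);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    size_t needed = 0;
    status = infiniopGetAttentionWorkspaceSize(desc, &needed);
    if (status == INFINI_STATUS_SUCCESS && needed > workspace_capacity) {
        status = INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    if (status == INFINI_STATUS_SUCCESS) {
        status = infiniopAttention(desc, workspace, needed,
                                   out, q, k, v, k_cache, v_cache, stream);
    }

    infiniopDestroyAttentionDescriptor(desc);
    return status;
}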