Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
adbda4c4
Unverified
Commit
adbda4c4
authored
Aug 13, 2025
by
thatPepe
Committed by
GitHub
Aug 13, 2025
Browse files
issue/214 - Elementwise Support for Cambricon Bang
parent
0f250ec4
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1122 additions
and
41 deletions
+1122
-41
src/infiniop/devices/bang/bang_handle.cc
src/infiniop/devices/bang/bang_handle.cc
+9
-1
src/infiniop/devices/bang/bang_kernel_common.h
src/infiniop/devices/bang/bang_kernel_common.h
+229
-0
src/infiniop/devices/bang/common_bang.h
src/infiniop/devices/bang/common_bang.h
+12
-0
src/infiniop/elementwise/bang/elementwise_bang.h
src/infiniop/elementwise/bang/elementwise_bang.h
+212
-0
src/infiniop/elementwise/bang/elementwise_bang_api.h
src/infiniop/elementwise/bang/elementwise_bang_api.h
+84
-0
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
+331
-0
src/infiniop/ops/add/bang/add_bang.h
src/infiniop/ops/add/bang/add_bang.h
+8
-0
src/infiniop/ops/add/bang/add_bang.mlu
src/infiniop/ops/add/bang/add_bang.mlu
+68
-0
src/infiniop/ops/add/bang/add_bang_internal.mlu
src/infiniop/ops/add/bang/add_bang_internal.mlu
+25
-0
src/infiniop/ops/add/operator.cc
src/infiniop/ops/add/operator.cc
+15
-0
src/infiniop/ops/swiglu/bang/swiglu_bang.h
src/infiniop/ops/swiglu/bang/swiglu_bang.h
+8
-0
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
+68
-0
src/infiniop/ops/swiglu/bang/swiglu_bang_internal.mlu
src/infiniop/ops/swiglu/bang/swiglu_bang_internal.mlu
+37
-0
src/infiniop/ops/swiglu/operator.cc
src/infiniop/ops/swiglu/operator.cc
+11
-33
test/infiniop/libinfiniop/utils.py
test/infiniop/libinfiniop/utils.py
+2
-4
test/infiniop/swiglu.py
test/infiniop/swiglu.py
+2
-2
xmake/bang.lua
xmake/bang.lua
+1
-1
No files found.
src/infiniop/devices/bang/bang_handle.cc
View file @
adbda4c4
...
...
@@ -8,12 +8,17 @@ namespace device::bang {
Handle
::
Handle
(
infiniDevice_t
device
,
int
device_id
)
:
InfiniopHandle
{
device
,
device_id
},
_internal
(
std
::
make_shared
<
Handle
::
Internal
>
())
{}
_internal
(
std
::
make_shared
<
Handle
::
Internal
>
(
device_id
))
{}
auto
Handle
::
internal
()
const
->
const
std
::
shared_ptr
<
Internal
>
&
{
return
_internal
;
}
// Cache the device topology once at construction time so kernel-launch code
// never has to re-query the driver.
// NOTE(review): as in the original, the cnrtDeviceGetAttribute return codes
// are ignored; on failure the cached fields keep indeterminate values.
// A constructor cannot report a status, so this is left as-is — consider
// logging or CNRT_CHECK if hard-failing is acceptable.
Handle::Internal::Internal(int device_id) {
    // The two queries are independent of each other.
    cnrtDeviceGetAttribute(&_core_per_cluster, cnrtAttrMcorePerCluster, device_id);
    cnrtDeviceGetAttribute(&_cluster_count, cnrtAttrClusterCount, device_id);
}
infiniStatus_t
Handle
::
Internal
::
useCnnl
(
cnrtQueue_t
queue
,
const
Fn
<
cnnlHandle_t
>
&
f
)
const
{
auto
handle
=
cnnl_handles
.
pop
();
if
(
!
handle
)
{
...
...
@@ -25,6 +30,9 @@ infiniStatus_t Handle::Internal::useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_
return
INFINI_STATUS_SUCCESS
;
}
// Number of MLU cores per cluster, cached at handle construction
// (queried via cnrtAttrMcorePerCluster).
int Handle::Internal::getCorePerCluster() const { return _core_per_cluster; }
// Number of clusters on the device, cached at handle construction
// (queried via cnrtAttrClusterCount).
int Handle::Internal::getClusterCount() const { return _cluster_count; }
cnnlDataType_t
getCnnlDtype
(
infiniDtype_t
dt
)
{
switch
(
dt
)
{
case
INFINI_DTYPE_F32
:
...
...
src/infiniop/devices/bang/bang_kernel_common.h
0 → 100644
View file @
adbda4c4
#ifndef __INFINIOP_BANG_KERNEL_COMMON_H__
#define __INFINIOP_BANG_KERNEL_COMMON_H__
// Include Cambricon CNNL and CNRT headers for MLU (Machine Learning Unit) specific functions
#include "cnnl.h"
#include "cnrt.h"
namespace
device
::
bang
::
kernel
{
/**
 * @brief Maps a flat index in the (broadcast) output space back to an offset
 *        in the original, non-broadcast input tensor.
 *
 * Walks the dimensions from outermost to innermost: the quotient against the
 * broadcast stride recovers the coordinate along that dimension, which is
 * then re-projected with the target tensor's own stride.
 *
 * @param flat_index Flattened index in the output tensor.
 * @param ndim Number of dimensions.
 * @param broadcasted_strides Strides of the broadcast (output-shaped) view.
 * @param target_strides Strides of the original tensor.
 * @return Element offset into the original tensor's memory.
 */
inline __mlu_device__ size_t indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    size_t remainder = flat_index;
    for (size_t dim = 0; dim < ndim; ++dim) {
        // Coordinate along this dimension in the broadcast view...
        const size_t coord = remainder / broadcasted_strides[dim];
        // ...re-projected through the original tensor's stride.
        offset += coord * target_strides[dim];
        remainder %= broadcasted_strides[dim];
    }
    return offset;
}
/**
 * @brief Maps a flat index to a memory offset for a strided (possibly
 *        non-contiguous) tensor.
 *
 * Peels coordinates off from the innermost dimension outwards: the remainder
 * against the dimension extent gives the coordinate, which is scaled by that
 * dimension's stride.
 *
 * @param flat_index Flattened index in the tensor.
 * @param ndim Number of dimensions.
 * @param shape Tensor shape.
 * @param strides Tensor strides (in elements).
 * @return Element offset into the tensor's memory.
 */
inline __mlu_device__ size_t indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    size_t remainder = flat_index;
    for (int dim = static_cast<int>(ndim) - 1; dim >= 0; --dim) {
        offset += (remainder % shape[dim]) * strides[dim];
        remainder /= shape[dim];
    }
    return offset;
}
/**
 * @brief Resolves per-input memory offsets for elementwise kernels whose
 *        inputs may be broadcast and/or non-contiguous.
 *
 * All pointer members refer to device-resident metadata arrays; the per-input
 * arrays are concatenated, so input `k` owns the slice [k*ndim, (k+1)*ndim).
 */
struct InputIndexer {
    size_t idx;                      // base linear index assigned to this task
    size_t ndim;                     // tensor rank
    const bool *input_contiguous;    // per-input: memory is contiguous
    const bool *input_broadcasted;   // per-input: shape was broadcast to output
    const size_t *input_shapes;      // concatenated input shapes (N * ndim)
    const ptrdiff_t *input_strides;  // concatenated input strides (N * ndim)
    const ptrdiff_t *output_strides; // output strides (broadcast reference frame)

    /**
     * @brief Memory offset of one element of one input.
     *
     * @param input_id Which input tensor.
     * @param element_idx Element index relative to this task's base index.
     * @return Element offset into that input's memory.
     */
    __mlu_device__ size_t operator()(size_t input_id, size_t element_idx) const {
        const size_t global_idx = idx + element_idx;
        // Fast path: contiguous inputs are addressed directly.
        if (input_contiguous[input_id]) {
            return global_idx;
        }
        // Broadcast inputs: undo the broadcast via the output strides.
        if (input_broadcasted[input_id]) {
            return indexToReducedOffset(global_idx, ndim, output_strides,
                                        input_strides + input_id * ndim);
        }
        // General strided (non-contiguous, non-broadcast) case.
        return indexToOffset(global_idx, ndim,
                             input_shapes + input_id * ndim,
                             input_strides + input_id * ndim);
    }
};
/**
 * @brief Memory offset of an output element, honoring striding.
 *
 * @param idx Linear element index.
 * @param is_contiguous Whether the output is contiguous (identity mapping).
 * @param ndim Number of dimensions.
 * @param shape Output tensor shape.
 * @param strides Output tensor strides.
 * @return Element offset into the output tensor's memory.
 */
inline __mlu_device__ size_t getOutputIndex(size_t idx, bool is_contiguous,
                                            size_t ndim, const size_t *shape,
                                            const ptrdiff_t *strides) {
    if (is_contiguous) {
        return idx;
    }
    return indexToOffset(idx, ndim, shape, strides);
}
/**
 * @brief Calculates the longest contiguous run starting at a global index,
 *        capped at @p max_len, for chunked memory copies.
 *
 * Scans dimensions from the innermost outwards to find the trailing group of
 * dimensions whose strides form a dense (row-major) block, then computes how
 * many elements remain in that block from the current position.
 *
 * Fix: declared `inline` — this is a non-template function defined in a
 * header, so without `inline` every translation unit that includes the header
 * emits its own external definition, causing multiple-definition/ODR errors
 * at link time. All sibling functions in this header are already `inline`.
 *
 * @param global_idx_ Starting global (flat) index.
 * @param ndim Number of dimensions.
 * @param shape Tensor shape.
 * @param strides Tensor strides (in elements).
 * @param max_len Maximum allowed chunk size.
 * @return Number of elements that can be copied as one contiguous chunk
 *         (at least 1).
 */
inline __mlu_device__ size_t calculateChunkSize(size_t global_idx_,
                                                size_t ndim,
                                                const size_t *shape,
                                                const ptrdiff_t *strides,
                                                size_t max_len) {
    // Find the outermost dimension of the trailing dense block: strides must
    // match the row-major expectation built up from the innermost dimension.
    int last_contiguous_dim = -1;
    ptrdiff_t expected_stride = 1;
    for (int i = static_cast<int>(ndim) - 1; i >= 0; --i) {
        if (strides[i] != expected_stride) {
            break;
        }
        last_contiguous_dim = i;
        if (i > 0) {
            expected_stride *= shape[i];
        }
    }
    // No dense trailing block at all: elements must be copied one by one.
    if (last_contiguous_dim < 0) {
        return 1;
    }
    // Position of global_idx_ inside the dense block and the block's size.
    size_t global_idx = global_idx_;
    size_t pos_in_block = 0;
    size_t block_size = 1;
    for (int i = static_cast<int>(ndim) - 1; i >= last_contiguous_dim; --i) {
        size_t dim_idx = global_idx % shape[i];
        pos_in_block += dim_idx * block_size;
        block_size *= shape[i];
        global_idx /= shape[i];
    }
    size_t remaining_in_block = block_size - pos_in_block;
    return std::min(max_len, remaining_in_block);
}
/**
 * @brief Copies a batch of elements between GDRAM and NRAM when the GDRAM-side
 *        tensor is not contiguous, splitting the transfer into the largest
 *        contiguous chunks calculateChunkSize can find.
 *
 * Issues one __memcpy_async per chunk; the caller is responsible for the
 * synchronization barrier (the call sites in launchOp follow with __sync_io()).
 *
 * @param dst Destination buffer (NRAM for input copies, GDRAM for output copies).
 * @param src Source buffer (GDRAM for input copies, NRAM for output copies).
 * @param direction Copy direction (GDRAM2NRAM or NRAM2GDRAM).
 * @param indexer Input offset resolver (used only when is_input_copy is true).
 * @param input_idx Which input tensor is being copied (input copies only).
 * @param processed Elements already handled by earlier batches.
 * @param curr_batch Number of elements to copy in this call.
 * @param start_idx This task's base index into the flattened output.
 * @param ndim Number of dimensions.
 * @param shape Shape of the strided (GDRAM-side) tensor.
 * @param strides Strides of the strided (GDRAM-side) tensor.
 * @param is_input_copy true: GDRAM input -> NRAM; false: NRAM -> GDRAM output.
 */
template <typename Tdata>
__mlu_device__ void nonContiguousMemcpy(Tdata *dst,
                                        Tdata *src,
                                        mluMemcpyDirection_t direction,
                                        InputIndexer &indexer,
                                        size_t input_idx,
                                        size_t processed,
                                        size_t curr_batch,
                                        size_t start_idx,
                                        size_t ndim,
                                        const size_t *shape,
                                        const ptrdiff_t *strides,
                                        bool is_input_copy) {
    size_t remaining = curr_batch;
    size_t current_pos = 0;
    while (remaining > 0) {
        // Offset of the next element on the strided (GDRAM) side. For inputs
        // the indexer also handles broadcasting; for outputs we always take
        // the strided path (contiguous outputs never reach this function).
        size_t element_offset = is_input_copy
            ? indexer(input_idx, processed + current_pos)
            : getOutputIndex(start_idx + processed + current_pos,
                             false, // output_contiguous is false for non-contiguous
                             ndim, shape, strides);
        // Largest contiguous run available from this position, capped at the
        // number of elements still owed.
        size_t chunk_size = calculateChunkSize(start_idx + processed + current_pos,
                                               ndim, shape, strides, remaining);
        // The NRAM side is always densely packed at current_pos; the GDRAM
        // side uses the strided element_offset.
        __memcpy_async(dst + (is_input_copy ? current_pos : element_offset),
                       src + (is_input_copy ? element_offset : current_pos),
                       chunk_size * sizeof(Tdata),
                       direction);
        current_pos += chunk_size;
        remaining -= chunk_size;
    }
}
}
// namespace device::bang::kernel
#endif
src/infiniop/devices/bang/common_bang.h
View file @
adbda4c4
...
...
@@ -2,6 +2,7 @@
#define __COMMON_BANG_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include "../pool.h"
#include "bang_handle.h"
#include "cnnl.h"
...
...
@@ -10,16 +11,27 @@
#define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)

// Usable NRAM scratch space per core, in bytes.
// Fix: parenthesized — the unparenthesized `1024 * 240` silently regroups
// under neighboring operators of higher/equal precedence (e.g. a future
// `x / NRAM_MAX_SIZE` would compute `x / 1024 * 240`).
#define NRAM_MAX_SIZE (1024 * 240)

// Byte alignment required for NRAM buffers handed to vectorized ops.
constexpr size_t ALIGN_SIZE = 128;
namespace
device
::
bang
{
// Per-device state shared by all descriptors created from a bang Handle:
// a pool of reusable CNNL handles plus the device topology cached at
// construction time.
class Handle::Internal {
    Pool<cnnlHandle_t> cnnl_handles; // recycled CNNL handles, see useCnnl()
    // Topology values filled in by the constructor via cnrtDeviceGetAttribute.
    // NOTE(review): no in-class initializers — if the driver queries in the
    // ctor fail, these remain indeterminate; consider defaulting to 0.
    int _core_per_cluster;
    int _cluster_count;

    // Callback type used by useCnnl: receives a ready handle of type T and
    // returns an infiniop status.
    template <typename T>
    using Fn = std::function<infiniStatus_t(T)>;

public:
    // Queries and caches the device topology for the given device id.
    Internal(int /*device_id*/);
    // Runs `f` with a CNNL handle taken from (and returned to) the pool;
    // presumably the handle is bound to `queue` first — confirm in the .cc.
    infiniStatus_t useCnnl(cnrtQueue_t queue, const Fn<cnnlHandle_t> &f) const;
    int getCorePerCluster() const; // cores per cluster (cnrtAttrMcorePerCluster)
    int getClusterCount() const;   // cluster count (cnrtAttrClusterCount)
};
cnnlDataType_t
getCnnlDtype
(
infiniDtype_t
dt
);
...
...
src/infiniop/elementwise/bang/elementwise_bang.h
0 → 100644
View file @
adbda4c4
#ifndef __INFINIOP_ELEMENTWISE_BANG_H__
#define __INFINIOP_ELEMENTWISE_BANG_H__
#include "../../../utils.h"
#include "../../devices/bang/common_bang.h"
#include "elementwise_bang_api.h"
namespace
op
::
elementwise
::
bang
{
/**
 * @brief Opaque implementation structure for BANG device operations.
 *
 * Holds the shared per-device internals and implements the host-side driver
 * that stages metadata into the workspace and launches the elementwise kernel.
 */
struct DeviceImpl::Opaque {
    // Shared per-device state (CNNL pool, cached topology).
    std::shared_ptr<device::bang::Handle::Internal> internal;

    /**
     * @brief Constructs an Opaque instance with device handle internals.
     *
     * @param internal_ Shared pointer to BANG device handle internals.
     */
    Opaque(const std::shared_ptr<device::bang::Handle::Internal> &internal_)
        : internal(internal_) {}

    /**
     * @brief Implements elementwise calculation for the BANG device.
     *
     * Stages the operation metadata into the device workspace, launches the
     * operator's kernel, then blocks until the queue drains.
     *
     * @tparam N Number of input tensors (must match Op::num_inputs).
     * @tparam Op Operator functor type (provides static launch<Tdata>).
     * @tparam Tdata Data type for inputs and output.
     * @tparam Args Additional arguments for the operator.
     *
     * @param info Elementwise operation metadata (shapes, strides, flags).
     * @param workspace Device workspace memory (input ptr array + metadata).
     * @param output Output tensor buffer (device).
     * @param inputs Host vector of device input pointers.
     * @param queue BANG queue for asynchronous execution.
     * @param args Additional arguments forwarded to the operator.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <size_t N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 cnrtQueue_t queue,
                                 Args &&...args) {
        auto output_size = info.getOutputSize();
        // Nothing to compute for an empty output.
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }

        // Device pointers into the workspace; set by infoToDevice below.
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;

        // Copy metadata to device and carve the workspace into typed regions.
        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides));

        // Launch the operator's kernel. Pointers are passed as void* because
        // the kernel entry point is declared via LAUNCH_ELEMENTWISE_KERNEL.
        // NOTE(review): `args...` is passed by value here rather than
        // std::forward<Args>(args)... — confirm that is intentional.
        Op::template launch<Tdata>(
            output_size,
            info.getNdim(),
            info.isOutputContiguous(),
            reinterpret_cast<const void *>(d_input_contiguous),
            reinterpret_cast<const void *>(d_input_broadcasted),
            reinterpret_cast<const void *>(d_output_shape),
            reinterpret_cast<const void *>(d_input_shapes),
            reinterpret_cast<const void *>(d_output_strides),
            reinterpret_cast<const void *>(d_input_strides),
            output,
            reinterpret_cast<const void *const *>(d_inputs_arr),
            queue,
            internal,
            args...);

        // Synchronize queue to ensure completion before returning.
        CNRT_CHECK(cnrtQueueSync(queue));
        return INFINI_STATUS_SUCCESS;
    }

private:
    /**
     * @brief Transfers elementwise operation metadata to device memory.
     *
     * Workspace layout (in order): N input pointers, output shape (ndim),
     * output strides (ndim), input shapes (N*ndim), input strides (N*ndim),
     * input contiguous flags (N), input broadcast flags (N). The metadata
     * block is copied verbatim from info.getMetaStart(), so this ordering
     * must match ElementwiseInfo's packing.
     *
     * @tparam N Number of input tensors.
     *
     * @param info Elementwise operation metadata.
     * @param workspace Device workspace memory.
     * @param h_inputs_arr Host array of device input pointers.
     * @param d_inputs_arr Out: device copy of the input pointer array.
     * @param d_input_contiguous Out: device contiguous flags.
     * @param d_input_broadcasted Out: device broadcast flags.
     * @param d_output_shape Out: device output shape.
     * @param d_output_strides Out: device output strides.
     * @param d_input_shapes Out: device input shapes.
     * @param d_input_strides Out: device input strides.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <size_t N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        const void **&d_inputs_arr,
        const bool *&d_input_contiguous,
        const bool *&d_input_broadcasted,
        const size_t *&d_output_shape,
        const ptrdiff_t *&d_output_strides,
        const size_t *&d_input_shapes,
        const ptrdiff_t *&d_input_strides) const {
        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        const int8_t *info_meta_start = info.getMetaStart();
        // Metadata lives immediately after the input pointer array.
        const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;

        // Copy input pointer array and the packed metadata blob to device.
        CNRT_CHECK(cnrtMemcpy(workspace, (void *)h_inputs_arr, input_arr_size, CNRT_MEM_TRANS_DIR_HOST2DEV));
        CNRT_CHECK(cnrtMemcpy((void *)d_meta_start, (void *)info_meta_start, info.getMetaMemSize(), CNRT_MEM_TRANS_DIR_HOST2DEV));

        // Carve the copied blob into typed views (no further copies).
        d_inputs_arr = reinterpret_cast<const void **>(workspace);
        d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);

        return INFINI_STATUS_SUCCESS;
    }
};
/**
 * @brief Creates a DeviceImpl instance for the BANG device.
 *
 * @tparam Args Argument types forwarded to the Opaque constructor.
 * @param args Arguments forwarded to the Opaque constructor.
 * @return utils::Result<DeviceImpl*> Result owning the new DeviceImpl.
 */
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
    // std::move avoids an unnecessary shared_ptr copy (atomic refcount
    // increment/decrement) when handing ownership to the DeviceImpl ctor.
    return utils::Result<DeviceImpl *>(new DeviceImpl(std::move(opaque)));
}
/**
 * @brief Calculates an elementwise operation on the BANG device.
 *
 * Thin forwarding layer: derives the input arity from the operator type and
 * delegates to Opaque::calculateImpl with the queue restored to its concrete
 * cnrtQueue_t type.
 *
 * @tparam Op Operator functor type (must expose num_inputs).
 * @tparam Tdata Data type for inputs and output.
 * @tparam Args Additional arguments for the operator.
 *
 * @param info Elementwise operation metadata.
 * @param workspace Device workspace memory.
 * @param output Output tensor buffer.
 * @param inputs Vector of input tensor pointers.
 * @param queue BANG queue, type-erased as void* at the public API boundary.
 * @param args Additional arguments forwarded to the operator.
 * @return infiniStatus_t Status indicating success or failure.
 */
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *queue,
                                     Args &&...args) {
    // Arity is a compile-time property of the operator.
    constexpr size_t N = Op::num_inputs;
    return _opaque->calculateImpl<N, Op, Tdata>(info, workspace, output, inputs,
                                                reinterpret_cast<cnrtQueue_t>(queue),
                                                std::forward<Args>(args)...);
}
}
// namespace op::elementwise::bang
/**
* @brief Macro for declaring BANG kernel interface.
*
* @param OpName Name of the elementwise operation.
*/
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
cnrtQueue_t queue, \
const std::shared_ptr<device::bang::Handle::Internal> &internal, \
Args... args);
#endif // __INFINIOP_ELEMENTWISE_BANG_H__
src/infiniop/elementwise/bang/elementwise_bang_api.h
0 → 100644
View file @
adbda4c4
#ifndef __INFINIOP_ELEMENTWISE_BANG_API_H__
#define __INFINIOP_ELEMENTWISE_BANG_API_H__
#include "../elementwise.h"
namespace
op
::
elementwise
::
bang
{
/**
 * @brief BANG device implementation for elementwise operations.
 *
 * Provides the interface for creating and executing elementwise operations on
 * BANG devices. Device-specific state lives behind the pimpl (`Opaque`), so
 * this header needs no BANG toolchain headers.
 */
class DeviceImpl final {
    struct Opaque; // defined in elementwise_bang.h (requires BANG headers)
    std::shared_ptr<Opaque> _opaque;

    // Private: instances are created exclusively via create().
    // NOTE(review): single-argument ctor is not `explicit`; harmless while
    // private, but worth marking if visibility ever changes.
    DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}

public:
    ~DeviceImpl() = default;

    /**
     * @brief Creates a DeviceImpl instance.
     *
     * @tparam Args Argument types for construction.
     * @param args Arguments forwarded to the implementation.
     * @return utils::Result<DeviceImpl*> Result containing the new instance.
     */
    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);

    /**
     * @brief Executes an elementwise operation on the BANG device.
     *
     * @tparam Op Operator functor type.
     * @tparam Tdata Data type for inputs and output.
     * @tparam Args Additional arguments for the operator.
     *
     * @param info Elementwise operation metadata.
     * @param workspace Device workspace memory.
     * @param output Output tensor buffer.
     * @param inputs Vector of input tensor pointers.
     * @param queue BANG queue (type-erased as void*).
     * @param args Additional arguments for the operator.
     * @return infiniStatus_t Status indicating success or failure.
     */
    template <typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
                             void *workspace,
                             void *output,
                             const std::vector<const void *> &inputs,
                             void *queue,
                             Args &&...args);
};
}
// namespace op::elementwise::bang
/**
* @brief Macro for creating BANG elementwise operation descriptor.
*
* @param HANDLE Device handle.
* @param DTYPE Output data type.
* @param OUT_DESC Output tensor descriptor.
* @param INPUT_DESC_VEC Vector of input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_BANG_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::bang::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_BANG_API_H__
src/infiniop/elementwise/bang/elementwise_bang_kernel.h
0 → 100644
View file @
adbda4c4
#ifndef __INFINIOP_ELEMENTWISE_BANG_KERNEL_MLU__
#define __INFINIOP_ELEMENTWISE_BANG_KERNEL_MLU__

#include "../../devices/bang/bang_kernel_common.h"
#include "../../devices/bang/common_bang.h"

// NOTE(review): a file-scope using-directive in a header injects
// device::bang::kernel into every includer. Presumably acceptable because
// this header is only included by backend .mlu translation units — confirm,
// or qualify the names instead.
using namespace device::bang::kernel;
/**
 * @brief Core per-task elementwise loop: stage inputs into NRAM, apply the
 *        operator, write results back to GDRAM, batching to fit NRAM.
 *
 * NRAM layout per batch: N input slices of max_batch elements followed by one
 * output slice of max_batch elements, starting at an ALIGN_SIZE-aligned base.
 *
 * NOTE(review): the operator invocation below passes exactly
 * input_buffers[0] and input_buffers[1], so despite the template parameter N
 * this path only works for binary operators (N == 2); the static_assert does
 * not catch that. Generalize the call before adding unary/ternary ops.
 *
 * @tparam N Number of input tensors.
 * @tparam Op Operator functor type (callable on NRAM buffers).
 * @tparam Tdata Data type for inputs and output.
 * @tparam Args Additional arguments for the operator.
 *
 * @param typed_inputs Array of typed GDRAM input pointers.
 * @param output GDRAM output pointer.
 * @param nram_buf NRAM scratch buffer (unaligned; aligned internally).
 * @param input_indexes Per-input offset of this task's first element.
 * @param output_index Offset of this task's first output element.
 * @param num_elements Number of elements this task processes.
 * @param output_contiguous Whether the output is contiguous.
 * @param input_contiguous Per-input contiguity flags.
 * @param input_broadcasted Per-input broadcast flags.
 * @param ndim Number of dimensions.
 * @param input_shapes Concatenated input shapes (N * ndim).
 * @param input_strides Concatenated input strides (N * ndim).
 * @param output_shape Output shape.
 * @param output_strides Output strides.
 * @param indexer Input offset resolver.
 * @param start_idx This task's base index into the flattened output.
 * @param args Additional arguments forwarded to the operator.
 */
template <size_t N, typename Op, typename Tdata, typename... Args>
__mlu_device__ void launchOp(Tdata **typed_inputs,
                             Tdata *output,
                             Tdata *nram_buf,
                             size_t *input_indexes,
                             size_t output_index,
                             size_t num_elements,
                             bool output_contiguous,
                             const bool *input_contiguous,
                             const bool *input_broadcasted,
                             size_t ndim,
                             const size_t *input_shapes,
                             const ptrdiff_t *input_strides,
                             const size_t *output_shape,
                             const ptrdiff_t *output_strides,
                             InputIndexer indexer,
                             size_t start_idx,
                             Args... args) {
    static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!");

    // NRAM memory planning: reserve alignment slack for N inputs + 1 output,
    // then split the remainder evenly into N+1 element slices.
    const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * (N + 1));
    const size_t max_batch = nram_usable / ((N + 1) * sizeof(Tdata));
    size_t processed = 0;
    while (processed < num_elements) {
        size_t curr_batch = std::min(max_batch, num_elements - processed);

        // Round the scratch base up to the next ALIGN_SIZE boundary.
        // (Loop-invariant; recomputed each iteration as in the original.)
        Tdata *aligned_buf = reinterpret_cast<Tdata *>(
            (reinterpret_cast<size_t>(nram_buf) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));

        // 1. Copy input data to NRAM.
        Tdata *input_buffers[N];
        for (size_t i = 0; i < N; ++i) {
            input_buffers[i] = aligned_buf + i * max_batch;
            if (input_contiguous[i]) {
                // Contiguous case - one bulk async copy.
                __memcpy_async(input_buffers[i],
                               typed_inputs[i] + input_indexes[i] + processed,
                               curr_batch * sizeof(Tdata),
                               GDRAM2NRAM);
            } else {
                // Non-contiguous case - copy in contiguous chunks.
                nonContiguousMemcpy<Tdata>(input_buffers[i], typed_inputs[i],
                                           GDRAM2NRAM, indexer, i, processed,
                                           curr_batch, start_idx, ndim,
                                           input_shapes + i * ndim,
                                           input_strides + i * ndim, true);
            }
        }
        // Wait for all pending async copies before computing.
        __sync_io();

        // 2. Execute the operation on the staged NRAM buffers.
        Tdata *output_buffer = aligned_buf + N * max_batch;
        Op op;
        // NOTE(review): hard-coded to two inputs — see header comment.
        op(output_buffer, input_buffers[0], input_buffers[1], curr_batch, args...);
        // Wait for compute to finish before writing back.
        __sync_compute();

        // 3. Write back results.
        if (output_contiguous) {
            // Contiguous output - one bulk async copy.
            __memcpy_async(output + output_index + processed, output_buffer,
                           curr_batch * sizeof(Tdata), NRAM2GDRAM);
        } else {
            // Non-contiguous output - copy in contiguous chunks.
            // (input_idx 0 is unused on the output path.)
            nonContiguousMemcpy<Tdata>(output, output_buffer, NRAM2GDRAM, indexer,
                                       0, processed, curr_batch, start_idx, ndim,
                                       output_shape, output_strides, false);
        }
        processed += curr_batch;
    }
}
/**
 * @brief BANG kernel entry point for elementwise operations.
 *
 * Splits the flattened output evenly across tasks (taskDim), resolves each
 * task's starting offsets, then delegates the batched copy/compute loop to
 * launchOp.
 *
 * @tparam N Number of input tensors.
 * @tparam Op Operator functor type.
 * @tparam Tdata Data type for inputs and output.
 * @tparam Args Additional arguments for the operator.
 *
 * @param output_size Total number of output elements.
 * @param ndim Number of dimensions.
 * @param output_contiguous Whether the output is contiguous.
 * @param input_contiguous Per-input contiguity flags (device memory).
 * @param input_broadcasted Per-input broadcast flags (device memory).
 * @param output_shape Output shape (device memory).
 * @param input_shapes Concatenated input shapes (device memory).
 * @param output_strides Output strides (device memory).
 * @param input_strides Concatenated input strides (device memory).
 * @param output Output tensor pointer.
 * @param inputs Array of type-erased input pointers.
 * @param args Additional arguments for the operator.
 */
template <size_t N, typename Op, typename Tdata, typename... Args>
__mlu_global__ void elementwiseKernel(size_t output_size,
                                      size_t ndim,
                                      bool output_contiguous,
                                      const bool *input_contiguous,
                                      const bool *input_broadcasted,
                                      const size_t *output_shape,
                                      const size_t *input_shapes,
                                      const ptrdiff_t *output_strides,
                                      const ptrdiff_t *input_strides,
                                      Tdata *output,
                                      const void *const *inputs,
                                      Args... args) {
    // Cast the type-erased input pointers to the element type.
    Tdata *typed_inputs[N];
    for (size_t i = 0; i < N; ++i) {
        typed_inputs[i] = reinterpret_cast<Tdata *>(const_cast<void *>(inputs[i]));
    }

    // Even split of the output across tasks; trailing tasks may get nothing.
    size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
    size_t start_idx = taskId * elements_per_task;
    size_t end_idx = std::min(start_idx + elements_per_task, output_size);
    size_t num_elements = end_idx > start_idx ? end_idx - start_idx : 0;
    if (num_elements == 0) {
        return;
    }

    // NRAM scratch shared by all input slices and the output slice.
    // NOTE(review): relies on NRAM_MAX_SIZE / sizeof(Tdata) grouping as
    // (1024*240)/sizeof(Tdata) — true by left-associativity even with the
    // unparenthesized macro, but fragile; prefer a parenthesized macro.
    __nram__ Tdata nram_buf[NRAM_MAX_SIZE / sizeof(Tdata)];

    // Memory offset of this task's first output element.
    size_t output_index = getOutputIndex(start_idx, output_contiguous, ndim,
                                         output_shape, output_strides);

    // Offset resolver for (possibly broadcast/strided) inputs.
    InputIndexer indexer{static_cast<size_t>(start_idx), ndim, input_contiguous,
                         input_broadcasted, input_shapes, input_strides,
                         output_strides};

    // Per-input offset of this task's first element (element_idx == 0).
    size_t input_indexes[N];
    for (size_t i = 0; i < N; ++i) {
        input_indexes[i] = indexer(i, 0);
    }

    // Run the batched copy/compute/write-back loop for this task's range.
    launchOp<N, Op, Tdata>(typed_inputs, output, nram_buf, input_indexes,
                           output_index, num_elements, output_contiguous,
                           input_contiguous, input_broadcasted, ndim,
                           input_shapes, input_strides, output_shape,
                           output_strides, indexer, start_idx, args...);
}
/**
 * @brief Host-side launcher: picks a launch configuration from the device
 *        topology, then enqueues elementwiseKernel.
 *
 * @tparam N Number of input tensors.
 * @tparam Op Operator functor type.
 * @tparam Tdata Data type for inputs and output.
 * @tparam Args Additional arguments for the operator.
 */
template <size_t N, typename Op, typename Tdata, typename... Args>
void launchElementwiseKernelWrapper(size_t output_size,
                                    size_t ndim,
                                    bool output_contiguous,
                                    const bool *input_contiguous,
                                    const bool *input_broadcasted,
                                    const size_t *output_shape,
                                    const size_t *input_shapes,
                                    const ptrdiff_t *output_strides,
                                    const ptrdiff_t *input_strides,
                                    Tdata *output,
                                    const void *const *inputs,
                                    cnrtQueue_t queue,
                                    const std::shared_ptr<device::bang::Handle::Internal> &internal,
                                    Args... args) {
    // Launch grid sized from the cached hardware topology:
    // x = cores per cluster, y = cluster count.
    cnrtDim3_t dim;
    dim.x = internal->getCorePerCluster();
    dim.y = internal->getClusterCount();
    dim.z = 1;

    // Large contiguous problems use the UNION1 function type; everything
    // else runs as independent BLOCK tasks.
    const bool large_and_contiguous = output_size > 1024 * 1024 && output_contiguous;
    const cnrtFunctionType_t func_type =
        large_and_contiguous ? CNRT_FUNC_TYPE_UNION1 : CNRT_FUNC_TYPE_BLOCK;

    // Enqueue the kernel with the chosen configuration.
    elementwiseKernel<N, Op, Tdata><<<dim, func_type, queue>>>(
        output_size, ndim, output_contiguous, input_contiguous,
        input_broadcasted, output_shape, input_shapes, output_strides,
        input_strides, output, inputs, args...);
}
/**
* @brief Macro for implementing elementwise kernel launch.
*
* @param OpName Name of the operation.
* @param Op Operator functor type.
*/
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
cnrtQueue_t queue, \
const std::shared_ptr<device::bang::Handle::Internal> &internal, \
Args... args) { \
launchElementwiseKernelWrapper<Op::num_inputs, Op, Tdata>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const size_t *>(output_shape), \
reinterpret_cast<const size_t *>(input_shapes), \
reinterpret_cast<const ptrdiff_t *>(output_strides), \
reinterpret_cast<const ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, queue, internal, args...); \
}
/**
 * @brief Macro for instantiating the elementwise kernel for specific types.
 *
 * @param OpName Name of the operation.
 * @param T Data type.
 * @param ... Additional template arguments.
 */
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
cnrtQueue_t queue, \
const std::shared_ptr<device::bang::Handle::Internal> &internal, \
##__VA_ARGS__);
#endif
src/infiniop/ops/add/bang/add_bang.h
0 → 100644
View file @
adbda4c4
#ifndef __ADD_BANG_API_H__
#define __ADD_BANG_API_H__

#include "../../../elementwise/bang/elementwise_bang.h"

// Declares op::add::bang::Descriptor via the shared elementwise
// descriptor-generating macro.
ELEMENTWISE_DESCRIPTOR(add, bang)

#endif // __ADD_BANG_API_H__
src/infiniop/ops/add/bang/add_bang.mlu
0 → 100644
View file @
adbda4c4
#include "add_bang.h"

// Operator Interface Declaration: declares launchAddKernel, whose definition
// and explicit instantiations live in add_bang_internal.mlu.
LAUNCH_ELEMENTWISE_KERNEL(Add)

namespace op::add::bang {

/// Host-side operator tag consumed by the elementwise framework; forwards the
/// type-erased launch arguments to the device launcher for dtype Tdata.
typedef struct AddOp {
    static constexpr size_t num_inputs = 2;

    template <typename Tdata, typename... Args>
    static infiniStatus_t launch(Args... args) {
        launchAddKernel<Tdata>(args...);
        return INFINI_STATUS_SUCCESS;
    }
} AddOp;

Descriptor::~Descriptor() = default;

/**
 * @brief Validates dtype/shapes and builds the Bang elementwise descriptor.
 *
 * @param handle_        Device handle; must wrap a device::bang::Handle.
 * @param desc_ptr       Receives the newly created descriptor.
 * @param out_desc       Output tensor descriptor; its dtype drives dispatch.
 * @param input_desc_vec Exactly two input descriptors (a, b).
 * @return INFINI_STATUS_SUCCESS, or a validation error from the CHECK macros.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    // Only F16/BF16/F32 are dispatched below; shapes must match exactly
    // (no broadcasting at this level).
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create Bang elementwise descriptor
    CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

/**
 * @brief Runs elementwise Add, dispatching on the descriptor's dtype.
 *
 * @param workspace      Scratch buffer of at least _workspace_size bytes.
 * @param workspace_size Size of @p workspace in bytes.
 * @param output         Destination tensor buffer.
 * @param inputs         Two input tensor buffers (a, b).
 * @param queue          cnrt queue the kernel is enqueued on.
 * @return Status from the underlying calculate, or an error status.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *queue) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every switch path returns, so no trailing return is needed (the
    // original had an unreachable `return INFINI_STATUS_SUCCESS;` here).
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<AddOp, half>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::add::bang
src/infiniop/ops/add/bang/add_bang_internal.mlu
0 → 100644
View file @
adbda4c4
#ifndef __ADD_BANG_INTERNAL_H__
#define __ADD_BANG_INTERNAL_H__

#include "../../../elementwise/bang/elementwise_bang_kernel.h"

/// Device-side functor computing elementwise out[i] = a[i] + b[i].
typedef struct AddOp {
public:
    static constexpr size_t num_inputs = 2;

    template <typename T>
    __mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
            // Vectorized add intrinsic for the supported floating-point types.
            __bang_add(out, a, b, num_elements);
        } else {
            // Scalar fallback for any other type. The original wrote
            // `out = a + b;`, which only reassigned the local pointer
            // (pointer + pointer, ill-formed if instantiated) and added
            // nothing element-wise; this loop performs the actual add.
            for (size_t i = 0; i < num_elements; ++i) {
                out[i] = a[i] + b[i];
            }
        }
    }
} AddOp;

LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)

// Explicit instantiations for the dtypes dispatched by Descriptor::calculate.
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)

#endif // __ADD_BANG_INTERNAL_H__
src/infiniop/ops/add/operator.cc
View file @
adbda4c4
...
...
@@ -14,6 +14,9 @@
#ifdef ENABLE_KUNLUN_API
#include "kunlun/add_kunlun.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/add_bang.h"
#endif
__C
infiniStatus_t
infiniopCreateAddDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -48,6 +51,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -78,6 +84,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
GET
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -118,6 +127,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -151,6 +163,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
src/infiniop/ops/swiglu/bang/swiglu_bang.h
0 → 100644
View file @
adbda4c4
#ifndef __SWIGLU_BANG_API_H__
#define __SWIGLU_BANG_API_H__

#include "../../../elementwise/bang/elementwise_bang.h"

// Expands the shared elementwise descriptor boilerplate for the
// (swiglu, bang) operator/device pair.
ELEMENTWISE_DESCRIPTOR(swiglu, bang)

#endif // __SWIGLU_BANG_API_H__
src/infiniop/ops/swiglu/bang/swiglu_bang.mlu
0 → 100644
View file @
adbda4c4
#include "swiglu_bang.h"

// Operator Interface Declaration: declares launchSwiGLUKernel, whose
// definition and explicit instantiations live in swiglu_bang_internal.mlu.
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)

namespace op::swiglu::bang {

/// Host-side operator tag consumed by the elementwise framework; forwards the
/// type-erased launch arguments to the device launcher for dtype Tdata.
typedef struct SwiGLUOp {
    static constexpr size_t num_inputs = 2;

    template <typename Tdata, typename... Args>
    static infiniStatus_t launch(Args... args) {
        launchSwiGLUKernel<Tdata>(args...);
        return INFINI_STATUS_SUCCESS;
    }
} SwiGLUOp;

Descriptor::~Descriptor() = default;

/**
 * @brief Validates dtype/shapes and builds the Bang elementwise descriptor.
 *
 * @param handle_        Device handle; must wrap a device::bang::Handle.
 * @param desc_ptr       Receives the newly created descriptor.
 * @param out_desc       Output tensor descriptor; its dtype drives dispatch.
 * @param input_desc_vec Exactly two input descriptors (up, gate).
 * @return INFINI_STATUS_SUCCESS, or a validation error from the CHECK macros.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &up_desc = input_desc_vec.at(0);
    const auto &gate_desc = input_desc_vec.at(1);
    const auto &out_shape = out_desc->shape();
    const auto &up_shape = up_desc->shape();
    const auto &gate_shape = gate_desc->shape();

    // Only F16/BF16/F32 are dispatched below; shapes must match exactly
    // (no broadcasting at this level).
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);

    // create Bang elementwise descriptor
    CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

/**
 * @brief Runs elementwise SwiGLU, dispatching on the descriptor's dtype.
 *
 * @param workspace      Scratch buffer of at least _workspace_size bytes.
 * @param workspace_size Size of @p workspace in bytes.
 * @param output         Destination tensor buffer.
 * @param inputs         Two input tensor buffers (up, gate).
 * @param queue          cnrt queue the kernel is enqueued on.
 * @return Status from the underlying calculate, or an error status.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *queue) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every switch path returns, so no trailing return is needed (the
    // original had an unreachable `return INFINI_STATUS_SUCCESS;` here).
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SwiGLUOp, half>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SwiGLUOp, bfloat16_t>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, queue);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::swiglu::bang
src/infiniop/ops/swiglu/bang/swiglu_bang_internal.mlu
0 → 100644
View file @
adbda4c4
#ifndef __SWIGLU_BANG_INTERNAL_H__
#define __SWIGLU_BANG_INTERNAL_H__

#include "../../../elementwise/bang/elementwise_bang_kernel.h"
#include "bang.h"
#include "bang_device_functions.h"

/// Device-side SwiGLU functor: out[i] = up[i] * gate[i] * sigmoid(gate[i])
/// (i.e. up * SiLU/Swish(gate)), over num_elements contiguous values.
typedef struct SwiGLUOp {
public:
    // The framework feeds this functor exactly two input tiles (up, gate).
    static constexpr size_t num_inputs = 2;

    template <typename T>
    __mlu_device__ void operator()(T *out, const T *up, const T *gate, size_t num_elements) const {
        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t>) {
            // Half-precision path via the sigmoid activation intrinsic:
            //   out = sigmoid(gate)
            __bang_active_sigmoid(out, gate, num_elements);
            //   out = gate * sigmoid(gate)        (swish)
            __bang_mul(out, out, gate, num_elements);
            //   out = up * swish(gate)
            __bang_mul(out, out, up, num_elements);
        } else if constexpr (std::is_same_v<T, float>) {
            // Float path composes sigmoid explicitly from exp:
            //   out = -gate
            __bang_neg(out, gate, num_elements);
            //   out = exp(-gate)
            __bang_active_exphp(out, out, num_elements);
            //   out = 1 + exp(-gate)
            __bang_add_scalar(out, out, 1.0f, num_elements);
            //   out = gate / (1 + exp(-gate)) == gate * sigmoid(gate)
            __bang_div(out, gate, out, num_elements);
            //   out = up * swish(gate)
            __bang_mul(out, up, out, num_elements);
        } else {
            // Scalar reference fallback; not reached by this file's explicit
            // instantiations, which cover only half/bfloat16_t/float.
            for (size_t i = 0; i < num_elements; ++i) {
                out[i] = up[i] * gate[i] / (1.0 + std::exp(-gate[i]));
            }
        }
    }
} SwiGLUOp;

LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)

// Explicit instantiations for the dtypes dispatched by Descriptor::calculate.
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, bfloat16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)

#endif // __SWIGLU_BANG_INTERNAL_H__
src/infiniop/ops/swiglu/operator.cc
View file @
adbda4c4
...
...
@@ -14,6 +14,9 @@
#ifdef ENABLE_METAX_API
#include "metax/swiglu_metax.h"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/swiglu_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_ascend.h"
#endif
...
...
@@ -51,23 +54,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangCreateSwiGLUDescriptor
((
BangHandle_t
)
handle
,
(
SwiGLUBangDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
}
#ifdef ENABLE_CAMBRICON_API
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
return
macaCreateSwiGLUDescriptor
((
MacaHandle_t
)
handle
,
(
SwiGLUMacaDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
return
musaCreateSwiGLUDescriptor
(
...
...
@@ -104,10 +96,8 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
#ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangGetSwiGLUWorkspaceSize
((
SwiGLUBangDescriptor_t
)
desc
,
size
);
}
#ifdef ENABLE_CAMBRICON_API
GET
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
...
...
@@ -155,18 +145,12 @@ __C infiniStatus_t infiniopSwiGLU(
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangSwiGLU
((
SwiGLUBangDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
}
#ifdef ENABLE_CAMBRICON_API
CALCULATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
return
macaSwiGLU
((
SwiGLUMacaDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
return
musaSwiGLU
((
SwiGLUMusaDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
...
...
@@ -204,18 +188,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangDestroySwiGLUDescriptor
((
SwiGLUBangDescriptor_t
)
desc
);
}
#ifdef ENABLE_CAMBRICON_API
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
#ifdef ENABLE_ASCEND_API
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
)
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
return
macaDestroySwiGLUDescriptor
((
SwiGLUMacaDescriptor_t
)
desc
);
#endif
#ifdef ENABLE_MTHREADS_GPU
case
DevMthreadsGpu
:
return
musaDestroySwiGLUDescriptor
((
SwiGLUMusaDescriptor_t
)
desc
);
...
...
test/infiniop/libinfiniop/utils.py
View file @
adbda4c4
...
...
@@ -605,11 +605,9 @@ def get_test_devices(args):
def
get_sync_func
(
device
):
import
torch
device_str
=
torch_device_map
[
device
]
if
device
==
InfiniDeviceEnum
.
CPU
:
if
device
==
InfiniDeviceEnum
.
CPU
or
device
==
InfiniDeviceEnum
.
CAMBRICON
:
sync
=
None
else
:
sync
=
getattr
(
torch
,
device_str
).
synchronize
sync
=
getattr
(
torch
,
infiniDeviceEnum_str_map
[
device
]
).
synchronize
return
sync
test/infiniop/swiglu.py
View file @
adbda4c4
...
...
@@ -64,8 +64,8 @@ _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
InfiniDtype
.
F16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-3
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
5
e-
3
},
InfiniDtype
.
F32
:
{
"atol"
:
2
e-
7
,
"rtol"
:
1e-
7
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
1
e-
2
},
InfiniDtype
.
F32
:
{
"atol"
:
1
e-
5
,
"rtol"
:
1e-
5
},
}
DEBUG
=
False
...
...
xmake/bang.lua
View file @
adbda4c4
local
NEUWARE_HOME
=
os.getenv
(
"NEUWARE_HOME"
)
or
"/usr/local/neuware"
add_includedirs
(
path
.
join
(
NEUWARE_HOME
,
"include"
))
add_includedirs
(
path
.
join
(
NEUWARE_HOME
,
"include"
)
,
{
public
=
true
}
)
add_linkdirs
(
path
.
join
(
NEUWARE_HOME
,
"lib64"
))
add_linkdirs
(
path
.
join
(
NEUWARE_HOME
,
"lib"
))
add_links
(
"libcnrt.so"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment