Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
10c9525f
Commit
10c9525f
authored
Mar 24, 2025
by
Zimin Li
Browse files
issue/46: Add binary infrastructure and refactor swiglu cpu using binary
parent
150dde0c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
184 additions
and
252 deletions
+184
-252
src/infiniop/devices/cpu/common_cpu.cc
src/infiniop/devices/cpu/common_cpu.cc
+1
-1
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
+40
-115
src/infiniop/ops/swiglu/cpu/swiglu_cpu.h
src/infiniop/ops/swiglu/cpu/swiglu_cpu.h
+19
-16
src/infiniop/ops/swiglu/cpu/swiglu_cpu_api.h
src/infiniop/ops/swiglu/cpu/swiglu_cpu_api.h
+0
-25
src/infiniop/ops/swiglu/operator.cc
src/infiniop/ops/swiglu/operator.cc
+109
-80
test/infiniop/swiglu.py
test/infiniop/swiglu.py
+15
-15
No files found.
src/infiniop/devices/cpu/common_cpu.cc
View file @
10c9525f
...
@@ -19,7 +19,7 @@ size_t indexToOffset(
...
@@ -19,7 +19,7 @@ size_t indexToOffset(
const
size_t
*
shape
,
const
size_t
*
shape
,
const
ptrdiff_t
*
strides
)
{
const
ptrdiff_t
*
strides
)
{
size_t
res
=
0
;
size_t
res
=
0
;
for
(
size_t
i
=
ndim
;
i
--
>
=
0
;)
{
for
(
size_t
i
=
ndim
;
i
--
>
0
;)
{
res
+=
(
flat_index
%
shape
[
i
])
*
strides
[
i
];
res
+=
(
flat_index
%
shape
[
i
])
*
strides
[
i
];
flat_index
/=
shape
[
i
];
flat_index
/=
shape
[
i
];
}
}
...
...
src/infiniop/ops/swiglu/cpu/swiglu_cpu.cc
View file @
10c9525f
#include "swiglu_cpu.h"
#include "swiglu_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include <cmath>
#include <cstdlib>
infiniopStatus_t
cpuCreateSwiGLUDescriptor
(
namespace
op
::
swiglu
::
cpu
{
infiniopCpuHandle_t
handle
,
infiniopSwiGLUCpuDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
auto
const
out
=
c_desc
,
up
=
a_desc
,
gate
=
b_desc
;
auto
dtype
=
out
->
dtype
;
Descriptor
::~
Descriptor
()
=
default
;
// Check dtypes
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle_
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
out_desc
,
infiniopTensorDescriptor_t
up_desc
,
infiniopTensorDescriptor_t
gate_desc
)
{
constexpr
infiniDtype_t
SUPPORTED_DTYPES
[]
=
{
auto
handle
=
reinterpret_cast
<
device
::
cpu
::
Handle
*>
(
handle_
);
constexpr
std
::
array
<
infiniDtype_t
,
3
>
SUPPORTED_DTYPES
=
{
INFINI_DTYPE_F16
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_F32
,
INFINI_DTYPE_F64
,
INFINI_DTYPE_F64
,
};
};
auto
supported
=
false
;
for
(
auto
supported_dtype
:
SUPPORTED_DTYPES
)
{
if
(
dtype
==
supported_dtype
)
{
supported
=
true
;
break
;
}
}
if
(
!
supported
||
gate
->
dtype
!=
dtype
||
up
->
dtype
!=
dtype
)
{
return
INFINIOP_STATUS_BAD_TENSOR_DTYPE
;
}
// Check shapes
// Perform generic binary operator check
CHECK_STATUS
(
op
::
common_cpu
::
binary_op
::
check
(
out_desc
,
up_desc
,
gate_desc
,
SUPPORTED_DTYPES
,
true
,
true
));
if
(
out
->
ndim
!=
2
||
gate
->
ndim
!=
2
||
up
->
ndim
!=
2
)
{
return
INFINIOP_STATUS_BAD_TENSOR_SHAPE
;
}
auto
const
n
=
out
->
shape
[
0
],
d
=
out
->
shape
[
1
],
n_g
=
gate
->
shape
[
0
],
d_g
=
gate
->
shape
[
1
],
n_u
=
up
->
shape
[
0
],
d_u
=
up
->
shape
[
1
];
if
(
n_g
!=
n
||
n_u
!=
n
||
d_g
!=
d
||
d_u
!=
d
)
{
return
INFINIOP_STATUS_BAD_TENSOR_SHAPE
;
}
// Create descriptor
// Create descriptor
*
desc_ptr
=
new
Descriptor
(
*
desc_ptr
=
new
SwiGLUCpuDescriptor
{
out_desc
->
dtype
(),
INFINI_DEVICE_CPU
,
{
out_desc
,
up_desc
,
gate_desc
},
dtype
,
nullptr
,
n
,
handle
->
device
,
d
,
handle
->
device_id
);
out
->
strides
[
0
],
out
->
strides
[
1
],
return
INFINI_STATUS_SUCCESS
;
gate
->
strides
[
0
],
gate
->
strides
[
1
],
up
->
strides
[
0
],
up
->
strides
[
1
],
};
return
INFINIOP_STATUS_SUCCESS
;
}
template
<
class
T
>
T
sigmoid
(
T
x
)
{
return
1
/
(
1
+
std
::
exp
(
-
x
));
}
}
template
<
class
T
>
infiniStatus_t
Descriptor
::
calculate
(
T
swiglu
(
T
gate
,
T
up
)
{
void
*
c
,
return
gate
*
sigmoid
(
gate
)
*
up
;
const
void
*
a
,
}
const
void
*
b
,
void
*
stream
)
const
{
template
<
class
T
>
void
swiglu_ptr
(
uint8_t
*
out
,
uint8_t
const
*
gate
,
uint8_t
const
*
up
)
{
auto
out_
=
reinterpret_cast
<
T
*>
(
out
);
auto
gate_
=
reinterpret_cast
<
T
const
*>
(
gate
);
auto
up_
=
reinterpret_cast
<
T
const
*>
(
up
);
*
out_
=
swiglu
(
*
gate_
,
*
up_
);
}
template
<
>
void
swiglu_ptr
<
uint16_t
>
(
uint8_t
*
out
,
uint8_t
const
*
gate
,
uint8_t
const
*
up
)
{
auto
out_
=
reinterpret_cast
<
uint16_t
*>
(
out
);
auto
gate_
=
reinterpret_cast
<
uint16_t
const
*>
(
gate
);
auto
up_
=
reinterpret_cast
<
uint16_t
const
*>
(
up
);
*
out_
=
f32_to_f16
(
swiglu
(
f16_to_f32
(
*
gate_
),
f16_to_f32
(
*
up_
)));
}
infiniopStatus_t
cpuSwiGLU
(
infiniopSwiGLUCpuDescriptor_t
desc
,
void
*
c
,
void
const
*
a
,
void
const
*
b
)
{
auto
out
=
reinterpret_cast
<
uint8_t
*>
(
c
);
switch
(
_dtype
)
{
auto
up
=
reinterpret_cast
<
uint8_t
const
*>
(
a
);
auto
gate
=
reinterpret_cast
<
uint8_t
const
*>
(
b
);
auto
const
unit
=
infiniSizeof
(
desc
->
dtype
);
for
(
size_t
i
=
0
;
i
<
desc
->
n
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
desc
->
d
;
++
j
)
{
auto
out_
=
out
+
(
i
*
desc
->
s_no
+
j
*
desc
->
s_do
)
*
unit
;
auto
gate_
=
gate
+
(
i
*
desc
->
s_ng
+
j
*
desc
->
s_dg
)
*
unit
;
auto
up_
=
up
+
(
i
*
desc
->
s_nu
+
j
*
desc
->
s_du
)
*
unit
;
switch
(
desc
->
dtype
)
{
case
INFINI_DTYPE_F16
:
case
INFINI_DTYPE_F16
:
swiglu_ptr
<
uint16_t
>
(
out_
,
gate_
,
up_
);
op
::
common_cpu
::
binary_op
::
calculate
<
fp16_t
,
SwiGLUOp
>
(
_info
,
c
,
a
,
b
);
break
;
break
;
case
INFINI_DTYPE_F32
:
case
INFINI_DTYPE_F32
:
swiglu_ptr
<
float
>
(
out_
,
gate_
,
up_
);
op
::
common_cpu
::
binary_op
::
calculate
<
float
,
SwiGLUOp
>
(
_info
,
c
,
a
,
b
);
break
;
break
;
case
INFINI_DTYPE_F64
:
case
INFINI_DTYPE_F64
:
swiglu_ptr
<
double
>
(
out_
,
gate_
,
up_
);
op
::
common_cpu
::
binary_op
::
calculate
<
double
,
SwiGLUOp
>
(
_info
,
c
,
a
,
b
);
break
;
break
;
default:
default:
// unreachable
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
std
::
abort
();
}
}
}
}
return
INFINIOP_STATUS_SUCCESS
;
}
infiniopStatus_t
cpuDestroySwiGLUDescriptor
(
return
INFINI_STATUS_SUCCESS
;
infiniopSwiGLUCpuDescriptor_t
desc
)
{
delete
desc
;
return
INFINIOP_STATUS_SUCCESS
;
}
}
}
// namespace op::swiglu::cpu
src/infiniop/ops/swiglu/cpu/swiglu_cpu.h
View file @
10c9525f
#ifndef __
INFINIOP_
SWIGLU_CPU_H__
#ifndef __SWIGLU_CPU_H__
#define __
INFINIOP_
SWIGLU_CPU_H__
#define __SWIGLU_CPU_H__
#include ".
/swiglu_cpu_api
.h"
#include ".
./../../binary/cpu/binary
.h"
typedef
struct
SwiGLUCpuDescriptor
{
BINARY_DESCRIPTOR
(
swiglu
,
cpu
)
infiniDevice_t
device
;
infiniDtype_t
dtype
;
size_t
n
,
d
;
ptrdiff_t
s_no
,
// n stride of out
s_do
,
// d stride of out
s_ng
,
// n stride of gate
s_dg
,
// d stride of gate
s_nu
,
// n stride of up
s_du
;
// d stride of up
}
SwiGLUCpuDescriptor
;
#endif // __INFINIOP_SWIGLU_CPU_H__
struct
SwiGLUOp
{
private:
template
<
typename
T
>
T
sigmoid
(
const
T
&
x
)
const
{
return
1
/
(
1
+
std
::
exp
(
-
x
));
}
public:
template
<
typename
T
>
T
operator
()(
const
T
&
up
,
const
T
&
gate
)
const
{
return
gate
*
sigmoid
(
gate
)
*
up
;
}
};
#endif // __SWIGLU_CPU_H__
src/infiniop/ops/swiglu/cpu/swiglu_cpu_api.h
deleted
100644 → 0
View file @
150dde0c
#ifndef __INFINIOP_SWIGLU_CPU_API_H__
#define __INFINIOP_SWIGLU_CPU_API_H__
#include "../../../devices/cpu/cpu_handle.h"
#include "infiniop/operator.h"
struct
SwiGLUCpuDescriptor
;
typedef
struct
SwiGLUCpuDescriptor
*
infiniopSwiGLUCpuDescriptor_t
;
infiniopStatus_t
cpuCreateSwiGLUDescriptor
(
infiniopCpuHandle_t
handle
,
infiniopSwiGLUCpuDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
);
infiniopStatus_t
cpuSwiGLU
(
infiniopSwiGLUCpuDescriptor_t
desc
,
void
*
c
,
void
const
*
a
,
void
const
*
b
);
infiniopStatus_t
cpuDestroySwiGLUDescriptor
(
infiniopSwiGLUCpuDescriptor_t
desc
);
#endif // __INFINIOP_SWIGLU_CPU_API_H__
src/infiniop/ops/swiglu/operator.cc
View file @
10c9525f
...
@@ -2,112 +2,141 @@
...
@@ -2,112 +2,141 @@
#include "../../handle.h"
#include "../../handle.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/ops/swiglu.h"
#ifdef ENABLE_CPU_API
#include "cpu/swiglu_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/swiglu_bang.h"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/swiglu_maca.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
__C
infiniStatus_t
infiniopCreateSwiGLUDescriptor
(
infiniopHandle_t
handle
,
infiniopSwiGLUDescriptor_t
*
desc_ptr
,
infiniopHandle_t
handle
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopSwiGLUDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
c_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
)
{
infiniopTensorDescriptor_t
b_desc
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::swiglu::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::swiglu::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
a_desc, \
b_desc)
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
#ifdef ENABLE_CPU_API
case
INFINI_DEVICE_CPU
:
CREATE
(
INFINI_DEVICE_CPU
,
cpu
);
return
cpuCreateSwiGLUDescriptor
(
handle
,
(
infiniopSwiGLUCpuDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
#endif
#ifdef ENABLE_NV_GPU
case
DevNvGpu
:
return
cudaCreateSwiGLUDescriptor
((
CudaHandle_t
)
handle
,
(
SwiGLUCudaDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case
DevCambriconMlu
:
{
return
bangCreateSwiGLUDescriptor
((
BangHandle_t
)
handle
,
(
SwiGLUBangDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_CUDA_API
case
DevAscendNpu
:
CREATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
return
ascendCreateSwiGLUDescriptor
(
(
AscendHandle_t
)
handle
,
(
SwiGLUAscendDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
#endif
#ifdef ENABLE_METAX_GPU
case
DevMetaxGpu
:
{
return
macaCreateSwiGLUDescriptor
((
MacaHandle_t
)
handle
,
(
SwiGLUMacaDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
}
#endif
#endif
#ifdef ENABLE_MTHREADS_GPU
#ifdef ENABLE_CAMBRICON_API
case
DevMthreadsGpu
:
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
return
musaCreateSwiGLUDescriptor
(
handle
,
(
SwiGLUMusaDescriptor_t
*
)
desc_ptr
,
c_desc
,
a_desc
,
b_desc
);
#endif
#endif
}
#ifdef ENABLE_ASCEND_API
CREATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
maca
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
;
}
__C
infiniStatus_t
infiniopSwiGLU
(
infiniopSwiGLUDescriptor_t
desc
,
void
*
c
,
#undef CREATE
const
void
*
a
,
const
void
*
b
,
}
__C
infiniStatus_t
infiniopSwiGLU
(
infiniopSwiGLUDescriptor_t
desc
,
void
*
c
,
const
void
*
a
,
const
void
*
b
,
void
*
stream
)
{
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, a, b, stream)
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU
case
DevCpu
:
#ifdef ENABLE_CPU_API
return
cpuSwiGLU
((
SwiGLUCpuDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
return
cudaSwiGLU
((
SwiGLUCudaDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_API
case
DevCambriconMlu
:
{
CALCULATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
return
bangSwiGLU
((
SwiGLUBangDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendSwiGLU
((
SwiGLUAscendDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_API
case
DevMetaxGpu
:
CALCULATE
(
INFINI_DEVICE_METAX
,
maca
);
return
macaSwiGLU
((
SwiGLUMacaDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#endif
#ifdef ENABLE_MTHREADS_GPU
#ifdef ENABLE_KUNLUN_API
case
DevMthreadsGpu
:
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
return
musaSwiGLU
((
SwiGLUMusaDescriptor_t
)
desc
,
c
,
a
,
b
,
stream
);
#endif
#endif
}
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
}
}
__C
infiniStatus_t
__C
infiniStatus_t
infiniopDestroySwiGLUDescriptor
(
infiniopSwiGLUDescriptor_t
desc
)
{
infiniopDestroySwiGLUDescriptor
(
infiniopSwiGLUDescriptor_t
desc
)
{
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU
case
DevCpu
:
#ifdef ENABLE_CPU_API
return
cpuDestroySwiGLUDescriptor
((
SwiGLUCpuDescriptor_t
)
desc
);
DELETE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_CUDA_API
case
DevNvGpu
:
DELETE
(
INFINI_DEVICE_NVIDIA
,
cuda
);
return
cudaDestroySwiGLUDescriptor
((
SwiGLUCudaDescriptor_t
)
desc
);
#endif
#endif
#ifdef ENABLE_CAMBRICON_MLU
#ifdef ENABLE_CAMBRICON_API
case
DevCambriconMlu
:
{
DELETE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
return
bangDestroySwiGLUDescriptor
((
SwiGLUBangDescriptor_t
)
desc
);
}
#endif
#endif
#ifdef ENABLE_ASCEND_NPU
#ifdef ENABLE_ASCEND_API
case
DevAscendNpu
:
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
return
ascendDestroySwiGLUDescriptor
((
SwiGLUAscendDescriptor_t
)
desc
);
#endif
#endif
#ifdef ENABLE_METAX_GPU
#ifdef ENABLE_METAX_API
case
DevMetaxGpu
:
DELETE
(
INFINI_DEVICE_METAX
,
maca
);
return
macaDestroySwiGLUDescriptor
((
SwiGLUMacaDescriptor_t
)
desc
);
#endif
#endif
#ifdef ENABLE_MTHREADS_GPU
#ifdef ENABLE_KUNLUN_API
case
DevMthreadsGpu
:
DELETE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
return
musaDestroySwiGLUDescriptor
((
SwiGLUMusaDescriptor_t
)
desc
);
#endif
#endif
}
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DELETE
}
}
\ No newline at end of file
test/infiniop/swiglu.py
View file @
10c9525f
...
@@ -25,19 +25,25 @@ _TEST_CASES_ = [
...
@@ -25,19 +25,25 @@ _TEST_CASES_ = [
# shape, a_stride, b_stride, c_stride
# shape, a_stride, b_stride, c_stride
((
13
,
4
),
None
,
None
,
None
),
((
13
,
4
),
None
,
None
,
None
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
)),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
)),
#
((13, 4, 4), None, None, None),
((
13
,
4
,
4
),
None
,
None
,
None
),
#
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
)),
((
16
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
)),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
)),
#
((4, 4, 5632), None, None, None),
((
4
,
4
,
5632
),
None
,
None
,
None
),
#
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
((
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
)),
]
]
class
Inplace
(
Enum
):
OUT_OF_PLACE
=
auto
()
INPLACE_A
=
auto
()
INPLACE_B
=
auto
()
# Inplace options applied for each test case in _TEST_CASES_
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE
=
[
_INPLACE
=
[
"
Inplace.OUT_OF_PLACE
"
,
Inplace
.
OUT_OF_PLACE
,
"
Inplace.INPLACE_A
"
,
Inplace
.
INPLACE_A
,
"
Inplace.INPLACE_B
"
,
Inplace
.
INPLACE_B
,
]
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
...
@@ -48,7 +54,7 @@ _TEST_CASES = [
...
@@ -48,7 +54,7 @@ _TEST_CASES = [
]
]
# Data types used for testing
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
]
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
_TOLERANCE_MAP
=
{
...
@@ -61,12 +67,6 @@ NUM_PRERUN = 10
...
@@ -61,12 +67,6 @@ NUM_PRERUN = 10
NUM_ITERATIONS
=
1000
NUM_ITERATIONS
=
1000
class
Inplace
(
Enum
):
OUT_OF_PLACE
=
auto
()
INPLACE_A
=
auto
()
INPLACE_B
=
auto
()
class
SwiGLUDescriptor
(
Structure
):
class
SwiGLUDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
_fields_
=
[(
"device"
,
c_int32
)]
...
@@ -132,7 +132,7 @@ def test(
...
@@ -132,7 +132,7 @@ def test(
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for
tensor
in
[
a_tensor
,
b_tensor
,
c_tensor
]:
for
tensor
in
[
a_tensor
,
b_tensor
,
c_tensor
]:
tensor
.
des
criptor
.
contents
.
invalidate
(
)
tensor
.
des
troyDesc
(
lib
)
def
lib_swiglu
():
def
lib_swiglu
():
check_error
(
check_error
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment