Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d478a389
Commit
d478a389
authored
May 24, 2022
by
myamlak
Browse files
Review remarks: binary ops templated
parent
ac9ef30b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
17 deletions
+54
-17
example/19_binary_elementwise/broadcast_add_2d.cpp
example/19_binary_elementwise/broadcast_add_2d.cpp
+6
-5
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+5
-4
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+34
-2
No files found.
example/19_binary_elementwise/broadcast_add_2d.cpp
View file @
d478a389
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
>
;
...
@@ -37,19 +38,19 @@ void host_broadcast2D(
...
@@ -37,19 +38,19 @@ void host_broadcast2D(
{
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
ComputeDataType
Amn
=
static_cas
t
<
ComputeDataType
>
(
A
(
m
,
n
));
ComputeDataType
Amn
=
ck
::
type_conver
t
<
ComputeDataType
>
(
A
(
m
,
n
));
ComputeDataType
Cmn
=
0
;
ComputeDataType
Cmn
=
0
;
if
constexpr
(
broadcastDim
==
0
)
if
constexpr
(
broadcastDim
==
0
)
{
{
ComputeDataType
Bn
=
static_cas
t
<
ComputeDataType
>
(
B
(
n
));
ComputeDataType
Bn
=
ck
::
type_conver
t
<
ComputeDataType
>
(
B
(
n
));
functor
(
Cmn
,
Amn
,
Bn
);
functor
(
Cmn
,
Amn
,
Bn
);
}
}
else
else
{
{
ComputeDataType
Bm
=
static_cas
t
<
ComputeDataType
>
(
B
(
m
));
ComputeDataType
Bm
=
ck
::
type_conver
t
<
ComputeDataType
>
(
B
(
m
));
functor
(
Cmn
,
Amn
,
Bm
);
functor
(
Cmn
,
Amn
,
Bm
);
}
}
C
(
m
,
n
)
=
static_cas
t
<
ctype
>
(
Cmn
);
C
(
m
,
n
)
=
ck
::
type_conver
t
<
ctype
>
(
Cmn
);
}
}
}
}
}
}
...
...
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
d478a389
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
>
;
...
@@ -34,11 +35,11 @@ void host_elementwise1D(
...
@@ -34,11 +35,11 @@ void host_elementwise1D(
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
ComputeDataType
Am
=
static_cas
t
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
Am
=
ck
::
type_conver
t
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
Bm
=
static_cas
t
<
ComputeDataType
>
(
B
(
m
));
ComputeDataType
Bm
=
ck
::
type_conver
t
<
ComputeDataType
>
(
B
(
m
));
ComputeDataType
Cm
=
0
;
ComputeDataType
Cm
=
0
;
functor
(
Cm
,
Am
,
Bm
);
functor
(
Cm
,
Am
,
Bm
);
C
(
m
)
=
static_cas
t
<
ctype
>
(
Cm
);
C
(
m
)
=
ck
::
type_conver
t
<
ctype
>
(
Cm
);
}
}
}
}
...
...
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
d478a389
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
...
@@ -17,7 +17,8 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
>
;
...
@@ -40,11 +41,11 @@ void host_elementwise4D(HostTensorC& C,
...
@@ -40,11 +41,11 @@ void host_elementwise4D(HostTensorC& C,
for
(
std
::
size_t
h
=
0
;
h
<
shape
[
2
];
++
h
)
for
(
std
::
size_t
h
=
0
;
h
<
shape
[
2
];
++
h
)
for
(
std
::
size_t
w
=
0
;
w
<
shape
[
3
];
++
w
)
for
(
std
::
size_t
w
=
0
;
w
<
shape
[
3
];
++
w
)
{
{
ComputeDataType
a_val
=
static_cas
t
<
ComputeDataType
>
(
A
(
n
,
c
,
h
,
w
));
ComputeDataType
a_val
=
ck
::
type_conver
t
<
ComputeDataType
>
(
A
(
n
,
c
,
h
,
w
));
ComputeDataType
b_val
=
static_cas
t
<
ComputeDataType
>
(
B
(
n
,
c
,
h
,
w
));
ComputeDataType
b_val
=
ck
::
type_conver
t
<
ComputeDataType
>
(
B
(
n
,
c
,
h
,
w
));
ComputeDataType
c_val
=
0
;
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
functor
(
c_val
,
a_val
,
b_val
);
C
(
n
,
c
,
h
,
w
)
=
static_cas
t
<
ctype
>
(
c_val
);
C
(
n
,
c
,
h
,
w
)
=
ck
::
type_conver
t
<
ctype
>
(
c_val
);
}
}
}
}
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
d478a389
...
@@ -523,8 +523,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -523,8 +523,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
Add
=
using
Substract
=
ck
::
tensor_operation
::
binary_element_wise
::
Substract
;
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
CDataType
,
CDataType
,
CDataType
>
;
using
Substract
=
ck
::
tensor_operation
::
binary_element_wise
::
Substract
<
CDataType
,
CDataType
,
CDataType
>
;
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CDataType
,
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
d478a389
...
@@ -5,26 +5,42 @@ namespace ck {
...
@@ -5,26 +5,42 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
binary_element_wise
{
namespace
binary_element_wise
{
struct
Add
template
<
typename
Y
,
typename
X1
,
typename
X2
>
struct
Add
;
template
<
>
struct
Add
<
double
,
double
,
double
>
{
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
{
dst
=
src1
+
src2
;
dst
=
src1
+
src2
;
}
}
};
template
<
>
struct
Add
<
float
,
float
,
float
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
{
dst
=
src1
+
src2
;
dst
=
src1
+
src2
;
}
}
};
template
<
>
struct
Add
<
half_t
,
half_t
,
half_t
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
{
dst
=
src1
+
src2
;
dst
=
src1
+
src2
;
}
}
};
template
<
>
struct
Add
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
{
...
@@ -35,26 +51,42 @@ struct Add
...
@@ -35,26 +51,42 @@ struct Add
}
}
};
};
struct
Substract
template
<
typename
Y
,
typename
X1
,
typename
X2
>
struct
Substract
;
template
<
>
struct
Substract
<
double
,
double
,
double
>
{
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
{
dst
=
src1
-
src2
;
dst
=
src1
-
src2
;
}
}
};
template
<
>
struct
Substract
<
float
,
float
,
float
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
{
dst
=
src1
-
src2
;
dst
=
src1
-
src2
;
}
}
};
template
<
>
struct
Substract
<
half_t
,
half_t
,
half_t
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
operator
()(
half_t
&
dst
,
const
half_t
&
src1
,
const
half_t
&
src2
)
const
{
{
dst
=
src1
-
src2
;
dst
=
src1
-
src2
;
}
}
};
template
<
>
struct
Substract
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
{
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment