Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a7676df9
Commit
a7676df9
authored
May 19, 2022
by
myamlak
Browse files
Merge remote-tracking branch 'origin/develop' into myamlak/cgemm
parents
6ebcb667
aafc3ac2
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
119 additions
and
106 deletions
+119
-106
example/19_binary_elementwise/broadcast_add_2d.cpp
example/19_binary_elementwise/broadcast_add_2d.cpp
+4
-8
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+4
-8
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+14
-17
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+35
-39
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+16
-16
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+12
-1
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
...sor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
+17
-17
library/include/ck/library/host_tensor/host_utility.hpp
library/include/ck/library/host_tensor/host_utility.hpp
+17
-0
No files found.
example/19_binary_elementwise/broadcast_add_2d.cpp
View file @
a7676df9
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
...
...
@@ -13,7 +8,6 @@
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
...
...
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
F16
,
F16
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -37,6 +31,8 @@ template <typename HostTensorA,
void
host_broadcast2D
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
int
M
,
int
N
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
))
>
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
...
...
@@ -53,7 +49,7 @@ void host_broadcast2D(
ComputeDataType
Bm
=
static_cast
<
ComputeDataType
>
(
B
(
m
));
functor
(
Cmn
,
Amn
,
Bm
);
}
C
(
m
,
n
)
=
static_cast
<
ComputeDataT
ype
>
(
Cmn
);
C
(
m
,
n
)
=
static_cast
<
ct
ype
>
(
Cmn
);
}
}
}
...
...
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
a7676df9
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
...
...
@@ -13,7 +8,6 @@
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
...
...
@@ -26,7 +20,7 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
F16
,
F16
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -36,13 +30,15 @@ template <typename HostTensorA,
void
host_elementwise1D
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
int
M
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
))
>
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
ComputeDataType
Am
=
static_cast
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
Bm
=
static_cast
<
ComputeDataType
>
(
B
(
m
));
ComputeDataType
Cm
=
0
;
functor
(
Cm
,
Am
,
Bm
);
C
(
m
)
=
static_cast
<
ComputeDataT
ype
>
(
Cm
);
C
(
m
)
=
static_cast
<
ct
ype
>
(
Cm
);
}
}
...
...
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
a7676df9
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_reduce_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_utility.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
...
...
@@ -27,7 +21,7 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
F16
,
F16
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
>
;
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -40,6 +34,8 @@ void host_elementwise4D(HostTensorC& C,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
,
0
,
0
))
>
;
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
0
];
++
n
)
for
(
std
::
size_t
c
=
0
;
c
<
shape
[
1
];
++
c
)
for
(
std
::
size_t
h
=
0
;
h
<
shape
[
2
];
++
h
)
...
...
@@ -49,7 +45,7 @@ void host_elementwise4D(HostTensorC& C,
ComputeDataType
b_val
=
static_cast
<
ComputeDataType
>
(
B
(
n
,
c
,
h
,
w
));
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
C
(
n
,
c
,
h
,
w
)
=
static_cast
<
ComputeDataT
ype
>
(
c_val
);
C
(
n
,
c
,
h
,
w
)
=
static_cast
<
ct
ype
>
(
c_val
);
}
}
...
...
@@ -75,14 +71,15 @@ int main()
b_m_device_buf
.
ToDevice
(
b_m
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_device_buf
.
GetDeviceBuffer
(),
c_m_device_buf
.
GetDeviceBuffer
(),
ck
::
to_int_vector
(
nchw
),
ck
::
to_int_vector
(
a_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
b_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
c_m
.
mDesc
.
GetStrides
()),
Add
{});
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_device_buf
.
GetDeviceBuffer
(),
c_m_device_buf
.
GetDeviceBuffer
(),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
nchw
),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
a_m
.
mDesc
.
GetStrides
()),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
b_m
.
mDesc
.
GetStrides
()),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
c_m
.
mDesc
.
GetStrides
()),
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
a7676df9
...
...
@@ -19,18 +19,15 @@ template <typename ADataType,
index_t
ScalarPerVector
>
struct
DeviceBinaryElementwise
:
public
BaseOperator
{
DeviceBinaryElementwise
(
index_t
threadPerBlock
=
256
)
:
BaseOperator
(),
threadPerBlock_
(
threadPerBlock
)
{
}
DeviceBinaryElementwise
(
index_t
blockSize
=
256
)
:
BaseOperator
(),
blockSize_
(
blockSize
)
{}
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M0
>
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
threadPerBlock
)
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
threadPerBlock
*
ScalarPerVector
;
const
index_t
loop_step
=
gridSize
*
blockSize
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
...
...
@@ -40,10 +37,10 @@ struct DeviceBinaryElementwise : public BaseOperator
return
desc_m0_pad
;
}
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
in
dex_
t
>&
shape
,
const
std
::
vector
<
in
dex_
t
>&
stride
,
index_t
gridSize
,
index_t
threadPerBlock
)
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
Dim
>
{});
...
...
@@ -60,10 +57,10 @@ struct DeviceBinaryElementwise : public BaseOperator
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
Dim
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
threadPerBlock
);
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
blockSize
);
}
else
return
PadDescriptor_M0_1d
(
desc
,
gridSize
,
threadPerBlock
);
return
PadDescriptor_M0_1d
(
desc
,
gridSize
,
blockSize
);
}
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
...
...
@@ -80,26 +77,28 @@ struct DeviceBinaryElementwise : public BaseOperator
Argument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_c
,
const
std
::
vector
<
in
dex_
t
>&
shape
,
const
std
::
vector
<
in
dex_
t
>&
stride_a
,
const
std
::
vector
<
in
dex_
t
>&
stride_b
,
const
std
::
vector
<
in
dex_
t
>&
stride_c
,
ElementwiseFunctor
functor
,
index_t
threadPerBlock
)
index_t
blockSize
)
:
p_a_
(
p_a
),
p_b_
(
p_b
),
p_c_
(
p_c
),
shape_
(
shape
),
functor_
(
functor
),
gridSize_
(
120
)
// FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
,
threadPerBlock
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
,
threadPerBlock
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
,
threadPerBlock
);
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
,
blockSize
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
,
blockSize
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
,
blockSize
);
}
const
ADataType
*
p_a_
;
const
BDataType
*
p_b_
;
CDataType
*
p_c_
;
std
::
vector
<
int
>
shape_
;
GridDesc_M0
a_grid_desc_m0_
;
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
...
...
@@ -109,21 +108,21 @@ struct DeviceBinaryElementwise : public BaseOperator
struct
Invoker
:
public
BaseInvoker
{
Invoker
(
index_t
threadPerBlock
)
:
BaseInvoker
(),
threadPerBlock_
(
threadPerBlock
)
{}
Invoker
(
index_t
blockSize
)
:
BaseInvoker
(),
blockSize_
(
blockSize
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel
=
kernel_elementwise_1d
<
GridwiseBinEltwise
,
ADataType
,
BDataType
,
CDataType
,
GridDesc_M0
,
ElementwiseFunctor
>
;
const
auto
kernel
=
kernel_
binary_
elementwise_1d
<
GridwiseBinEltwise
,
ADataType
,
BDataType
,
CDataType
,
GridDesc_M0
,
ElementwiseFunctor
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
threadPerBlock
_
),
dim3
(
blockSize
_
),
0
,
arg
.
p_a_
,
arg
.
p_b_
,
...
...
@@ -142,7 +141,7 @@ struct DeviceBinaryElementwise : public BaseOperator
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
index_t
threadPerBlock
_
;
index_t
blockSize
_
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -152,10 +151,7 @@ struct DeviceBinaryElementwise : public BaseOperator
if
(
pArg
==
nullptr
)
return
false
;
// shape[0] * shape[1] * shape[2] * ...
const
auto
m0
=
pArg
->
c_grid_desc_m0_
.
GetLength
(
I0
);
if
(
m0
%
ScalarPerVector
!=
0
)
if
(
pArg
->
shape_
.
back
()
%
ScalarPerVector
!=
0
)
return
false
;
return
true
;
...
...
@@ -164,10 +160,10 @@ struct DeviceBinaryElementwise : public BaseOperator
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
std
::
vector
<
int
>
shape
,
std
::
vector
<
int
>
stride_a
,
std
::
vector
<
int
>
stride_b
,
std
::
vector
<
int
>
stride_c
,
std
::
vector
<
in
dex_
t
>
shape
,
std
::
vector
<
in
dex_
t
>
stride_a
,
std
::
vector
<
in
dex_
t
>
stride_b
,
std
::
vector
<
in
dex_
t
>
stride_c
,
ElementwiseFunctor
functor
)
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
@@ -178,12 +174,12 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_b
,
stride_c
,
functor
,
threadPerBlock
_
);
blockSize
_
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
threadPerBlock
_
});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
blockSize
_
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -200,7 +196,7 @@ struct DeviceBinaryElementwise : public BaseOperator
return
str
.
str
();
}
index_t
threadPerBlock
_
;
index_t
blockSize
_
;
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
a7676df9
...
...
@@ -71,10 +71,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
constexpr
auto
ScalarPerVector
=
Number
<
4
>
{};
template
<
typename
Desc_M0
>
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
threadPerBlock
)
static
auto
PadDescriptor_M0_1d
(
Desc_M0
desc_m0
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
threadPerBlock
*
ScalarPerVector
;
const
index_t
loop_step
=
gridSize
*
blockSize
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
...
...
@@ -87,7 +87,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
index_t
gridSize
,
index_t
threadPerBlock
)
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
2
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
2
>
{});
...
...
@@ -100,7 +100,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
2
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
threadPerBlock
);
return
PadDescriptor_M0_1d
(
desc_m0
,
gridSize
,
blockSize
);
}
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
...
...
@@ -536,18 +536,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
GridDesc_M0
,
Substract
,
ScalarPerVector
>
;
const
auto
add_kernel
=
kernel_elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Add
>
;
const
auto
substract_kernel
=
kernel_elementwise_1d
<
GridwiseBinSubstract
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Substract
>
;
const
auto
add_kernel
=
kernel_
binary_
elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Add
>
;
const
auto
substract_kernel
=
kernel_
binary_
elementwise_1d
<
GridwiseBinSubstract
,
CDataType
,
CDataType
,
CDataType
,
GridDesc_M0
,
Substract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
a7676df9
...
...
@@ -7,6 +7,12 @@ namespace binary_element_wise {
struct
Add
{
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
dst
=
src1
+
src2
;
}
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
...
...
@@ -32,6 +38,12 @@ struct Add
struct
Substract
{
__host__
__device__
constexpr
void
operator
()(
double
&
dst
,
const
double
&
src1
,
const
double
&
src2
)
const
{
dst
=
src1
-
src2
;
}
__host__
__device__
constexpr
void
operator
()(
float
&
dst
,
const
float
&
src1
,
const
float
&
src2
)
const
{
dst
=
src1
-
src2
;
...
...
@@ -43,7 +55,6 @@ struct Substract
dst
=
src1
-
src2
;
}
// TO FIX!!!
__host__
__device__
constexpr
void
operator
()(
bhalf_t
&
dst
,
const
bhalf_t
&
src1
,
const
bhalf_t
&
src2
)
const
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
View file @
a7676df9
...
...
@@ -13,13 +13,13 @@ template <typename GridwiseBinEltwise,
typename
CDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
>
__global__
void
kernel_elementwise_1d
(
const
ADataType
*
__restrict__
p_a_global
,
const
BDataType
*
__restrict__
p_b_global
,
CDataType
*
__restrict__
p_c_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
GridDesc_M0
c_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
__global__
void
kernel_
binary_
elementwise_1d
(
const
ADataType
*
__restrict__
p_a_global
,
const
BDataType
*
__restrict__
p_b_global
,
CDataType
*
__restrict__
p_c_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
GridDesc_M0
c_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
{
GridwiseBinEltwise
::
Run
(
p_a_global
,
p_b_global
,
...
...
@@ -45,7 +45,7 @@ struct GridwiseBinaryElementwise_1D
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
__device__
__host__
auto
CalculateElementwiseIndex
()
static
__device__
auto
CalculateElementwiseIndex
()
{
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
return
make_multi_index
(
global_thread_id
*
ScalarPerVector
);
...
...
@@ -70,7 +70,7 @@ struct GridwiseBinaryElementwise_1D
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
ScalarPerVector
,
true
>
b_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
ScalarPerVector
,
true
>
c_thread_buf
;
const
auto
thread_to_global_offset
=
CalculateElementwiseIndex
();
const
auto
thread_
s
to
re
_global_offset
=
CalculateElementwiseIndex
();
auto
a_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
ADataType
,
...
...
@@ -82,7 +82,7 @@ struct GridwiseBinaryElementwise_1D
0
,
// SrcVectorDim
ScalarPerVector
,
1
,
// SrcScalarStrideInVector
false
>
{
a_grid_desc_m0
,
thread_to_global_offset
};
false
>
{
a_grid_desc_m0
,
thread_
s
to
re
_global_offset
};
auto
b_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
BDataType
,
...
...
@@ -94,7 +94,7 @@ struct GridwiseBinaryElementwise_1D
0
,
// SrcVectorDim
ScalarPerVector
,
1
,
// SrcScalarStrideInVector
false
>
{
b_grid_desc_m0
,
thread_to_global_offset
};
false
>
{
b_grid_desc_m0
,
thread_
s
to
re
_global_offset
};
auto
c_global_write
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
...
...
@@ -109,13 +109,13 @@ struct GridwiseBinaryElementwise_1D
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
false
>
{
c_grid_desc_m0
,
thread_to_global_offset
,
PassThrough
{}};
c_grid_desc_m0
,
thread_
s
to
re
_global_offset
,
PassThrough
{}};
const
index_t
threadPerBlock
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
m0
=
c_grid_desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
threadPerBlock
*
ScalarPerVector
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
const
index_t
blockSize
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
m0
=
c_grid_desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
ScalarPerVector
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
index_t
num_iter
=
m0
/
(
loop_step
);
do
...
...
library/include/ck/library/host_tensor/host_utility.hpp
0 → 100644
View file @
a7676df9
#pragma once
#include <vector>
namespace
ck
{
template
<
typename
Src
,
typename
Dst
>
inline
std
::
vector
<
Dst
>
convert_vector_element_type
(
const
std
::
vector
<
Src
>&
inData
)
{
std
::
vector
<
Dst
>
outData
;
for
(
auto
elem
:
inData
)
outData
.
push_back
(
static_cast
<
Dst
>
(
elem
));
return
(
outData
);
};
};
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment