Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d36f7a2
Commit
5d36f7a2
authored
Apr 21, 2022
by
rocking
Browse files
Rewrite the elementwise operation.
Let memory coalesce between block
parent
88d621ac
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
59 deletions
+60
-59
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+15
-10
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+8
-9
include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
...sor_operation/gpu/device/device_binary_elementwise_2d.hpp
+20
-27
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
...sor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
+13
-13
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+4
-0
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
5d36f7a2
...
...
@@ -171,16 +171,21 @@ struct Div
using
DeviceElementwiseSubExpInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
EltwiseComputeDataType
,
Sub_Exp
,
256
,
32
,
8
>
;
using
DeviceElementwiseDivInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
EltwiseComputeDataType
,
Div
,
256
,
32
,
8
>
;
CDataType
,
CDataType
,
EltwiseComputeDataType
,
Sub_Exp
,
256
,
8
>
;
using
DeviceElementwiseDivInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
EltwiseComputeDataType
,
Div
,
256
,
8
>
;
using
HostGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
5d36f7a2
...
...
@@ -12,15 +12,14 @@ template <typename ElementwiseFunctor>
struct
DeviceBinaryElementwise
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
std
::
vector
<
int
>&
shape_a
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
shape_b
,
const
std
::
vector
<
int
>&
stride_b
,
ElementwiseFunctor
functor
)
=
0
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
std
::
vector
<
int
>&
shape_a
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
shape_b
,
const
std
::
vector
<
int
>&
stride_b
,
ElementwiseFunctor
functor
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
View file @
5d36f7a2
...
...
@@ -16,15 +16,14 @@ template <typename ADataType,
typename
ComputeDataType
,
typename
ElementwiseFunctor
,
index_t
ThreadPerBlock
,
index_t
ThreadTileSize
,
index_t
ScalarPerVector
>
struct
DeviceBinaryElementwise_2D
:
public
DeviceBinaryElementwise
<
ElementwiseFunctor
>
{
static_assert
(
ThreadTileSize
%
ScalarPerVector
==
0
);
static
constexpr
int
BlockTileSize
=
ThreadPerBlock
*
ThreadTileSize
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
)
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
index_t
gridSize
)
{
const
int
m
=
shape
[
0
];
const
int
n
=
shape
[
1
];
...
...
@@ -41,8 +40,9 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
make_tuple
(
Sequence
<
0
>
{}));
// pad
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
BlockTileSize
)
-
m0
;
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
ThreadPerBlock
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
...
...
@@ -51,15 +51,13 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
return
desc_m0_pad
;
}
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
}));
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
}
,
1
));
using
GridwiseBinEltwise
=
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
ComputeDataType
,
GridDesc_M0
,
ElementwiseFunctor
,
ThreadPerBlock
,
ThreadTileSize
,
ScalarPerVector
>
;
struct
Argument
:
public
BaseArgument
...
...
@@ -75,11 +73,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
:
p_a_
(
p_a
),
p_b_
(
p_b
),
p_c_
(
p_c
),
a_grid_desc_m0_
(
MakeDescriptor_M0
(
shape
,
stride_a
)),
b_grid_desc_m0_
(
MakeDescriptor_M0
(
shape
,
stride_b
)),
c_grid_desc_m0_
(
MakeDescriptor_M0
(
shape
,
stride_c
)),
functor_
(
functor
)
functor_
(
functor
),
gridSize_
(
128
)
// FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
);
}
const
ADataType
*
p_a_
;
...
...
@@ -89,30 +88,25 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
ElementwiseFunctor
functor_
;
index_t
gridSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
index_t
CalculateGridSize
(
const
GridDesc_M0
&
grid_desc_m0
)
{
const
auto
gridTileSize
=
grid_desc_m0
.
GetLength
(
I0
);
return
gridTileSize
/
BlockTileSize
;
}
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
const
auto
kernel
=
kernel_elementwise_1d
<
GridwiseBinEltwise
,
(
void
)
arg
;
const
auto
kernel
=
kernel_elementwise_1d
<
GridwiseBinEltwise
,
ADataType
,
BDataType
,
CDataType
,
GridDesc_M0
,
ElementwiseFunctor
>
;
float
avgTime
=
0
;
const
index_t
gridSize
=
CalculateGridSize
(
arg
.
c_grid_desc_m0_
);
float
avgTime
=
0
;
if
(
nrepeat
==
0
)
{
launch_kernel
(
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
gridSize
_
),
dim3
(
ThreadPerBlock
),
0
,
arg
.
p_a_
,
...
...
@@ -127,7 +121,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
{
avgTime
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
gridSize
),
dim3
(
arg
.
gridSize
_
),
dim3
(
ThreadPerBlock
),
0
,
arg
.
p_a_
,
...
...
@@ -157,7 +151,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// m * n
const
auto
m0
=
pArg
->
c_grid_desc_m0_
.
GetLength
(
I0
);
if
(
m0
%
BlockTileSize
!=
0
)
if
(
m0
%
ScalarPerVector
!=
0
)
return
false
;
return
true
;
...
...
@@ -195,7 +189,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
str
<<
"DeviceBinaryElementwise_2D"
<<
"<"
<<
"ThreadPerBlock = "
<<
ThreadPerBlock
<<
"ThreadTileSize = "
<<
ThreadTileSize
<<
"ScalarPerVector = "
<<
ScalarPerVector
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
View file @
5d36f7a2
...
...
@@ -36,13 +36,10 @@ template <typename ADataType,
typename
ComputeDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
,
index_t
ThreadPerBlock
,
index_t
ThreadTileSize
,
index_t
ScalarPerVector
>
struct
GridwiseBinaryElementwise_1D
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
int
BlockTileSize
=
ThreadPerBlock
*
ThreadTileSize
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
thread_desc_M0
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
ScalarPerVector
>
{}));
...
...
@@ -50,10 +47,8 @@ struct GridwiseBinaryElementwise_1D
static
__device__
__host__
auto
CalculateElementwiseIndex
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
return
make_multi_index
(
block_id
*
BlockTileSize
+
thread_id
*
ScalarPerVector
);
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
return
make_multi_index
(
global_thread_id
*
ScalarPerVector
);
}
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_global
,
...
...
@@ -116,8 +111,13 @@ struct GridwiseBinaryElementwise_1D
false
>
{
c_grid_desc_m0
,
thread_to_global_offset
,
PassThrough
{}};
int
num_iter
=
ThreadTileSize
/
ScalarPerVector
;
constexpr
auto
thread_to_global_step
=
make_multi_index
(
ThreadPerBlock
*
ScalarPerVector
);
const
index_t
threadPerBlock
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
m0
=
c_grid_desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
threadPerBlock
*
ScalarPerVector
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
index_t
num_iter
=
m0
/
(
loop_step
);
do
{
// read and process ScalarPerVector elements
...
...
@@ -140,9 +140,9 @@ struct GridwiseBinaryElementwise_1D
c_grid_desc_m0
,
c_global_buf
);
a_global_load
.
MoveSrcSliceWindow
(
a_grid_desc_m0
,
thread_to_global_step
);
b_global_load
.
MoveSrcSliceWindow
(
b_grid_desc_m0
,
thread_to_global_step
);
c_global_write
.
MoveDstSliceWindow
(
c_grid_desc_m0
,
thread_to_global_step
);
a_global_load
.
MoveSrcSliceWindow
(
a_grid_desc_m0
,
loop_step_index
);
b_global_load
.
MoveSrcSliceWindow
(
b_grid_desc_m0
,
loop_step_index
);
c_global_write
.
MoveDstSliceWindow
(
c_grid_desc_m0
,
loop_step_index
);
}
while
(
--
num_iter
);
}
};
...
...
include/ck/utility/get_id.hpp
View file @
5d36f7a2
...
...
@@ -7,10 +7,14 @@ __device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_global_1d_id
()
{
return
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
}
__device__
index_t
get_wave_local_1d_id
()
{
return
threadIdx
.
x
/
get_wave_size
();
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment