Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e3a09b57
Commit
e3a09b57
authored
Apr 11, 2022
by
rocking
Browse files
Add gridwise_elementwise_2d api
parent
6818b58c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
169 additions
and
8 deletions
+169
-8
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
.../ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
+86
-7
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+82
-0
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
e3a09b57
...
@@ -125,7 +125,7 @@ struct Sub
...
@@ -125,7 +125,7 @@ struct Sub
};
};
using
DeviceElementwiseInstance
=
using
DeviceElementwiseInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
256
,
Sub
>
;
ck
::
tensor_operation
::
device
::
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
Sub
,
16
,
16
,
8
,
8
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
...
...
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
View file @
e3a09b57
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include "device.hpp"
#include "device.hpp"
#include "device_elementwise.hpp"
#include "device_elementwise.hpp"
#include "gridwise_elementwise_2d.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -12,17 +13,36 @@ namespace device {
...
@@ -12,17 +13,36 @@ namespace device {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
index_t
BlockSize
,
typename
ElementwiseFunctor
,
typename
ElementwiseFunctor
>
index_t
MThreadPerBlock
,
index_t
NThreadPerBlock
,
index_t
MThreadTileSize
,
index_t
NThreadTileSize
>
struct
DeviceElementwise_2D
:
public
DeviceElementwise
<
ElementwiseFunctor
>
struct
DeviceElementwise_2D
:
public
DeviceElementwise
<
ElementwiseFunctor
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
auto
Make2dDescriptor_M_N
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
)
static
auto
Make2dDescriptor_M_N
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
shape
[
0
],
shape
[
1
]),
return
make_naive_tensor_descriptor
(
make_tuple
(
shape
[
0
],
shape
[
1
]),
make_tuple
(
stride
[
0
],
stride
[
1
]));
make_tuple
(
stride
[
0
],
stride
[
1
]));
}
}
using
GridDesc_M_N
=
decltype
(
Make2dDescriptor_M_N
({
1
,
1
},
{
1
,
1
}));
static
constexpr
index_t
BlockSize
=
MThreadPerBlock
*
NThreadPerBlock
;
static
constexpr
int
M_BlockTileSize
=
MThreadPerBlock
*
MThreadTileSize
;
static
constexpr
int
N_BlockTileSize
=
NThreadPerBlock
*
NThreadTileSize
;
using
GridDesc_M_N
=
decltype
(
Make2dDescriptor_M_N
({
1
,
1
},
{
1
,
1
}));
using
GridwiseEltwise
=
GridwiseElementwise_2D
<
ADataType
,
BDataType
,
CDataType
,
GridDesc_M_N
,
GridDesc_M_N
,
GridDesc_M_N
,
ElementwiseFunctor
,
MThreadTileSize
,
NThreadTileSize
>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -55,12 +75,63 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
...
@@ -55,12 +75,63 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
/// Number of workgroups needed to cover the M x N extent described by
/// `grid_desc_m_n`, with one workgroup per M_BlockTileSize x N_BlockTileSize tile.
///
/// Both extents are asserted to be exact multiples of the block tile; partial
/// tiles are not handled here.
index_t CalculateGridSize(const GridDesc_M_N& grid_desc_m_n)
{
    const auto M = grid_desc_m_n.GetLength(I0);
    const auto N = grid_desc_m_n.GetLength(I1);

    assert(M % M_BlockTileSize == 0);
    assert(N % N_BlockTileSize == 0);

    // Block-tile count along each axis; the grid is their product.
    const auto num_block_m = M / M_BlockTileSize;
    const auto num_block_n = N / N_BlockTileSize;

    return num_block_m * num_block_n;
}
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
const
auto
kernel
=
kernel_elementwise_2d
<
GridwiseEltwise
,
ADataType
,
BDataType
,
CDataType
,
GridDesc_M_N
,
GridDesc_M_N
,
GridDesc_M_N
,
ElementwiseFunctor
>
;
// TODO
// TODO
(
void
)
arg
;
(
void
)
arg
;
(
void
)
nrepeat
;
(
void
)
nrepeat
;
return
0
;
(
void
)
kernel
;
float
avgTime
=
0
;
const
index_t
gridSize
=
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
if
(
nrepeat
==
0
)
{
launch_kernel
(
kernel
,
dim3
(
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_
,
arg
.
p_b_
,
arg
.
p_c_
,
arg
.
a_grid_desc_m_n_
,
arg
.
b_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
functor_
);
}
else
{
avgTime
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_
,
arg
.
p_b_
,
arg
.
p_c_
,
arg
.
a_grid_desc_m_n_
,
arg
.
b_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
functor_
);
}
return
avgTime
;
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
...
@@ -71,9 +142,18 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
...
@@ -71,9 +142,18 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
{
// TODO: properly implement this check
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
return
pArg
!=
nullptr
;
// A failed dynamic_cast means the caller passed an argument object that
// does not belong to this operator.
if(pArg == nullptr)
    return false;

const auto M = pArg->c_grid_desc_m_n_.GetLength(I0);
const auto N = pArg->c_grid_desc_m_n_.GetLength(I1);

// Reject the shape if EITHER extent is not a whole multiple of its block
// tile. CalculateGridSize asserts divisibility of M and of N independently,
// so a shape with only one misaligned axis must not be reported as supported.
if(M % M_BlockTileSize != 0 || N % N_BlockTileSize != 0)
    return false;

return true;
};
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
...
@@ -107,7 +187,6 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
...
@@ -107,7 +187,6 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
// clang-format off
// clang-format off
str
<<
"DeviceElementwise_2D"
str
<<
"DeviceElementwise_2D"
<<
"<"
<<
"<"
<<
BlockSize
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
0 → 100644
View file @
e3a09b57
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
/// Kernel entry point for the 2D elementwise operation.
///
/// This is a thin trampoline: it forwards every argument unchanged to
/// `GridwiseEltwise::Run`, which holds the actual implementation.
///
/// @tparam GridwiseEltwise    Type providing a static `Run` taking these same arguments.
/// @tparam ADataType          Element type read through `p_a_global`.
/// @tparam BDataType          Element type read through `p_b_global`.
/// @tparam CDataType          Element type written through `p_c_global`.
/// @tparam AGridDesc_M_N      Descriptor type passed for A.
/// @tparam BGridDesc_M_N      Descriptor type passed for B.
/// @tparam CGridDesc_M_N      Descriptor type passed for C.
/// @tparam ElementwiseFunctor Type of the functor handed to `Run`.
///
/// The descriptors and the functor are taken by value.
template <typename GridwiseEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AGridDesc_M_N,
          typename BGridDesc_M_N,
          typename CGridDesc_M_N,
          typename ElementwiseFunctor>
__global__ void kernel_elementwise_2d(const ADataType* __restrict__ p_a_global,
                                      const BDataType* __restrict__ p_b_global,
                                      CDataType* __restrict__ p_c_global,
                                      const AGridDesc_M_N a_grid_desc_m_n,
                                      const BGridDesc_M_N b_grid_desc_m_n,
                                      const CGridDesc_M_N c_grid_desc_m_n,
                                      const ElementwiseFunctor functor)
{
    GridwiseEltwise::Run(p_a_global,
                         p_b_global,
                         p_c_global,
                         a_grid_desc_m_n,
                         b_grid_desc_m_n,
                         c_grid_desc_m_n,
                         functor);
}
/// Grid-level implementation of a 2D elementwise operation on A and B into C.
///
/// Work in progress: `Run` currently only builds the global buffer views and
/// the per-thread staging buffers. The load / apply-functor / store steps are
/// not implemented yet, so calling it has no observable effect.
///
/// @tparam MThreadTileSize,NThreadTileSize  Their product is the element count
///         of each per-thread buffer.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AGridDesc_M_N,
          typename BGridDesc_M_N,
          typename CGridDesc_M_N,
          typename ElementwiseFunctor,
          index_t MThreadTileSize,
          index_t NThreadTileSize>
struct GridwiseElementwise_2D
{
    /// Entry point invoked from the kernel with the raw pointers, their
    /// descriptors and the elementwise functor.
    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const AGridDesc_M_N a_grid_desc_m_n,
                               const BGridDesc_M_N b_grid_desc_m_n,
                               const CGridDesc_M_N c_grid_desc_m_n,
                               const ElementwiseFunctor functor)
    {
        // Number of elements each thread stages for one tile.
        constexpr index_t thread_tile_elems = MThreadTileSize * NThreadTileSize;

        // Per-thread staging buffers tagged with the Vgpr address space.
        // NOTE(review): meaning of the trailing `true` flag is defined by
        // StaticBuffer, which is not visible here - confirm before relying on it.
        using AThreadBuffer =
            StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, thread_tile_elems, true>;
        using BThreadBuffer =
            StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, thread_tile_elems, true>;
        using CThreadBuffer =
            StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, thread_tile_elems, true>;

        // Views over the three global allocations, each sized by its
        // descriptor's element space.
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m_n.GetElementSpaceSize());
        const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m_n.GetElementSpaceSize());

        AThreadBuffer a_thread_buf;
        BThreadBuffer b_thread_buf;
        CThreadBuffer c_thread_buf;

        // TODO - buffer_load, apply functor, buffer_store
        // Until then, mark everything as used so the stub builds without
        // unused-variable warnings.
        static_cast<void>(a_global_buf);
        static_cast<void>(b_global_buf);
        static_cast<void>(c_global_buf);
        static_cast<void>(a_thread_buf);
        static_cast<void>(b_thread_buf);
        static_cast<void>(c_thread_buf);
        static_cast<void>(functor);
    }
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment