Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a41f5481
Commit
a41f5481
authored
Apr 25, 2022
by
rocking
Browse files
1. Fix coding style
2. Use DeviceGemm_Xdl_CShuffle instead of deprecated DeviceGemmXdl_C_Shuffle
parent
680cfaa7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
19 deletions
+23
-19
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+15
-11
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
...sor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
+8
-8
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
a41f5481
...
...
@@ -15,7 +15,7 @@
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
...
...
@@ -50,19 +50,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
    ALayout,     // ALayout
    BLayout,     // BLayout
    CLayout,     // CLayout
    ADataType,   // ADataType
    BDataType,   // BDataType
    CDataType,   // CDataType
    AccDataType, // AccDataType
    CDataType,   // CShuffleDataType
    ALayout,     // ALayout
    BLayout,     // BLayout
    CLayout,     // CLayout
    PassThrough, // AElementwiseOperation
    PassThrough, // BElementwiseOperation
    PassThrough, // CElementwiseOperation
    GemmDefault, // GemmSpec
    1,           // NumGemmKPrefetchStage
    256,         // BlockSize
    256,         // MPerBlock
    128,         // NPerBlock
...
...
@@ -89,7 +93,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
    true,                 // BBlockLdsAddExtraN
    1,                    // CShuffleMXdlPerWavePerShuffle
    1,                    // CShuffleNXdlPerWavePerShuffle
    S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
    S<1, 32, 1, 8>,       // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    8>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
...
...
@@ -149,7 +153,7 @@ using DeviceReduceSumInstance =
    1,
    1>;

struct Sub_Exp
struct SubExp
{
    __host__ __device__ constexpr void
    operator()(EltwiseComputeDataType& dst, const EltwiseComputeDataType& src1,
...
...
@@ -174,7 +178,7 @@ using DeviceElementwiseSubExpInstance =
    CDataType,
    CDataType,
    EltwiseComputeDataType,
    Sub_Exp,
    SubExp,
    256,
    8>;
...
...
@@ -412,7 +416,7 @@ int main(int argc, char* argv[])
        {StrideC, 1},
        {0, 1},
        {StrideC, 1},
        Sub_Exp{});
        SubExp{});

    if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
    {
...
...
@@ -515,8 +519,8 @@ int main(int argc, char* argv[])
    Tensor<CDataType>,
    Tensor<CDataType>,
    EltwiseComputeDataType,
    Sub_Exp,
    0>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
    SubExp,
    0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{});

    host_reduce_sum.Run(1, // alpha
                        reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
View file @
a41f5481
...
...
@@ -40,7 +40,7 @@ template <typename ADataType,
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_M0 =
    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;
...
...
@@ -76,7 +76,7 @@ struct GridwiseBinaryElementwise_1D
    ThreadwiseTensorSliceTransfer_v2<ADataType,
                                     ComputeDataType,
                                     GridDesc_M0,
                                     decltype(thread_desc_M0),
                                     decltype(thread_desc_m0),
                                     Sequence<ScalarPerVector>, // SliceLengths
                                     Sequence<0>,               // DimAccessOrder
                                     0,                         // SrcVectorDim
...
...
@@ -88,7 +88,7 @@ struct GridwiseBinaryElementwise_1D
    ThreadwiseTensorSliceTransfer_v2<BDataType,
                                     ComputeDataType,
                                     GridDesc_M0,
                                     decltype(thread_desc_M0),
                                     decltype(thread_desc_m0),
                                     Sequence<ScalarPerVector>, // SliceLengths
                                     Sequence<0>,               // DimAccessOrder
                                     0,                         // SrcVectorDim
...
...
@@ -99,7 +99,7 @@ struct GridwiseBinaryElementwise_1D
    auto c_global_write =
        ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                           CDataType,
                                           decltype(thread_desc_M0),
                                           decltype(thread_desc_m0),
                                           GridDesc_M0,
                                           PassThrough,
                                           Sequence<ScalarPerVector>, // SliceLengths
...
...
@@ -122,19 +122,19 @@ struct GridwiseBinaryElementwise_1D
    {
        // read and process ScalarPerVector elements
        a_global_load.Run(a_grid_desc_m0, a_global_buf, thread_desc_M0, make_tuple(I0), a_thread_buf);
        a_global_load.Run(a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

        b_global_load.Run(b_grid_desc_m0, b_global_buf, thread_desc_M0, make_tuple(I0), b_thread_buf);
        b_global_load.Run(b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);

        static_for<0, ScalarPerVector, 1>{}([&](auto m) {
            constexpr auto offset = thread_desc_M0.CalculateOffset(make_tuple(m));
            constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
            functor(c_thread_buf(Number<offset>{}),
                    a_thread_buf(Number<offset>{}),
                    b_thread_buf(Number<offset>{}));
        });

        c_global_write.Run(thread_desc_M0,
        c_global_write.Run(thread_desc_m0,
                           make_tuple(I0), // SrcSliceOriginIdx
                           c_thread_buf,
                           c_grid_desc_m0,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment