Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b571256f
Commit
b571256f
authored
May 30, 2022
by
ltqin
Browse files
fix bug after merge develop
parent
f9c478e2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
42 deletions
+45
-42
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
+6
-7
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_b_lds.hpp
...ensor_operation/gpu/device/device_gemm_xdl_skip_b_lds.hpp
+8
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
+31
-28
No files found.
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
View file @
b571256f
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
...
@@ -82,7 +81,7 @@ using AccDataType = float;
...
@@ -82,7 +81,7 @@ using AccDataType = float;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
template
<
typename
DataType
>
template
<
typename
DataType
>
std
::
ostream
&
show_2d_matrix
(
std
::
ostream
&
os
,
Tensor
<
DataType
>&
matrix
)
std
::
ostream
&
show_2d_matrix
(
std
::
ostream
&
os
,
Tensor
<
DataType
>&
matrix
)
...
@@ -104,7 +103,7 @@ int main(int argc, char* argv[])
...
@@ -104,7 +103,7 @@ int main(int argc, char* argv[])
{
{
bool
do_verification
=
0
;
bool
do_verification
=
0
;
int
init_method
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
#if 1
#if 1
...
@@ -129,13 +128,13 @@ int main(int argc, char* argv[])
...
@@ -129,13 +128,13 @@ int main(int argc, char* argv[])
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
10
)
else
if
(
argc
==
10
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
...
@@ -149,7 +148,7 @@ int main(int argc, char* argv[])
...
@@ -149,7 +148,7 @@ int main(int argc, char* argv[])
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg3:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -228,7 +227,7 @@ int main(int argc, char* argv[])
...
@@ -228,7 +227,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
}
);
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_b_lds.hpp
View file @
b571256f
...
@@ -283,7 +283,7 @@ struct DeviceGemmXdlSkipBLds
...
@@ -283,7 +283,7 @@ struct DeviceGemmXdlSkipBLds
{
{
using
Argument
=
DeviceGemmXdlSkipBLds
::
Argument
;
using
Argument
=
DeviceGemmXdlSkipBLds
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{}
)
{
{
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
...
@@ -332,8 +332,8 @@ struct DeviceGemmXdlSkipBLds
...
@@ -332,8 +332,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
nrepeat
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -364,8 +364,8 @@ struct DeviceGemmXdlSkipBLds
...
@@ -364,8 +364,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
nrepeat
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
...
@@ -385,9 +385,10 @@ struct DeviceGemmXdlSkipBLds
...
@@ -385,9 +385,10 @@ struct DeviceGemmXdlSkipBLds
}
}
// polymorphic
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
nrepeat
);
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
View file @
b571256f
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "
blockwise
_tensor_slice_transfer_v4r1.hpp"
#include "
thread_group
_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -124,6 +124,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -124,6 +124,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
index_t
K0PerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
K0PerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -397,33 +399,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -397,33 +399,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
K0PerBlock
*
MultiK0
,
MPerBlock
,
K1
>
,
Sequence
<
K0PerBlock
*
MultiK0
,
MPerBlock
,
K1
>
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
2
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
true
,
1
>
(
a_grid_desc_k0_m_k1
,
1
>
(
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_grid_desc_k0_m_k1
,
a_element_op
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_block_desc_k0_m_k1
,
a_element_op
,
make_multi_index
(
0
,
0
,
0
),
a_block_desc_k0_m_k1
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ignore
=
b_element_op
;
ignore
=
b_element_op
;
// B matrix threadwise copy
// B matrix threadwise copy
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment