Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c5b5d2e4
Commit
c5b5d2e4
authored
May 08, 2022
by
Chao Liu
Browse files
clean up
parent
9685fed2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
42 additions
and
99 deletions
+42
-99
include/ck/tensor_operation/gpu/device/batched_gemm_util.hpp
include/ck/tensor_operation/gpu/device/batched_gemm_util.hpp
+0
-79
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+0
-3
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+0
-3
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
...ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+3
-9
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+39
-5
No files found.
include/ck/tensor_operation/gpu/device/batched_gemm_util.hpp
deleted
100644 → 0
View file @
9685fed2
#pragma once
#include "tuple.hpp"
#include "tensor_adaptor.hpp"
#include "multi_index_transform_helper.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Utilities for batched GEMM device operations: a factory for the adaptor
// used as the "Block2CTileMap", and a small functor that turns a batch index
// into per-matrix pointer offsets.
struct BatchedGemmUtil
{
    // Builds the adaptor obtained by chaining two single-stage adaptors.
    //
    // Template parameters:
    //   MPerBlock, NPerBlock - divisors applied to M and N respectively.
    //
    // Parameters:
    //   batch_count - length passed to the insert transform and used as the
    //                 leading length of the merge transform.
    //   M, N        - extents that are divided by MPerBlock / NPerBlock.
    //   M01, N01    - used as divisors of the per-block quotients and as the
    //                 inner lengths of the unmerge transforms (default 1).
    //
    // Returns the result of chain_tensor_adaptors(); judging by the local
    // names in the original code it maps a global block id to a
    // (batch, m0, n0) index -- NOTE(review): confirm against the gridwise
    // kernel that consumes it.
    template <index_t MPerBlock, index_t NPerBlock>
    static constexpr auto MakeBlock2CTileMap(
        index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
    {
        // Quotients of the extents by the per-block sizes.
        const auto m_tile_count = M / MPerBlock;
        const auto n_tile_count = N / NPerBlock;

        // Outer factors once the M01 / N01 inner factors are divided out.
        const auto m_outer = m_tile_count / M01;
        const auto n_outer = n_tile_count / N01;

        // Stage applied last in the chain: insert a batch_count-long
        // dimension and unmerge dimension 0 into (m_outer, M01) and
        // dimension 1 into (n_outer, N01).
        const auto split_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_insert_transform(batch_count),
                       make_unmerge_transform(make_tuple(m_outer, M01)),
                       make_unmerge_transform(make_tuple(n_outer, N01))),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        // Stage applied first in the chain: a single merge transform over
        // the lengths (batch_count, m_outer, n_outer, M01, N01).
        const auto merge_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(
                make_merge_transform(make_tuple(batch_count, m_outer, n_outer, M01, N01))),
            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
            make_tuple(Sequence<0>{}));

        return chain_tensor_adaptors(split_adaptor, merge_adaptor);
    }

    // Stores one stride per matrix (A, B, C) and multiplies it by a batch
    // index; the stride is widened to long_index_t before the multiplication.
    struct ComputePtrOffsetOfStridedBatch
    {
        // Captures the three batch strides unchanged.
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideC_(BatchStrideC)
        {
        }

        // Returns g_idx * BatchStrideA_ evaluated in long_index_t.
        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            const auto wide_stride = static_cast<long_index_t>(BatchStrideA_);
            return g_idx * wide_stride;
        }

        // Returns g_idx * BatchStrideB_ evaluated in long_index_t.
        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            const auto wide_stride = static_cast<long_index_t>(BatchStrideB_);
            return g_idx * wide_stride;
        }

        // Returns g_idx * BatchStrideC_ evaluated in long_index_t.
        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            const auto wide_stride = static_cast<long_index_t>(BatchStrideC_);
            return g_idx * wide_stride;
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideC_;
    };
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
c5b5d2e4
...
@@ -369,9 +369,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -369,9 +369,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
c5b5d2e4
...
@@ -187,9 +187,6 @@ struct DeviceGemmXdl
...
@@ -187,9 +187,6 @@ struct DeviceGemmXdl
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
// AGridDesc_K0_M_K1,
// BGridDesc_K0_N_K1,
// CGridDesc_M_N,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @
c5b5d2e4
...
@@ -299,14 +299,9 @@ struct DeviceGemmXdlSplitK
...
@@ -299,14 +299,9 @@ struct DeviceGemmXdlSplitK
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
static
constexpr
auto
MakeBlock2CTileMap
(
index_t
batch_count
,
static
constexpr
auto
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
MakeBlock2CTileMap
(
index_t
batch_count
,
index_t
M
,
index_t
N
,
index_t
M01
,
index_t
N01
)
index_t
M01
,
index_t
N01
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
...
@@ -363,7 +358,6 @@ struct DeviceGemmXdlSplitK
...
@@ -363,7 +358,6 @@ struct DeviceGemmXdlSplitK
private:
private:
index_t
BatchStrideA_
;
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
// index_t BatchStrideC_; // always zero
};
};
using
GridwiseGemm
=
using
GridwiseGemm
=
...
@@ -408,7 +402,7 @@ struct DeviceGemmXdlSplitK
...
@@ -408,7 +402,7 @@ struct DeviceGemmXdlSplitK
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
1
,
CGridDesc_M_N
{}
,
1
,
1
));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
1
,
1
,
1
,
1
,
1
));
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
c5b5d2e4
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#include "batched_gemm_util.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -370,6 +369,39 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -370,6 +369,39 @@ struct DeviceGemmXdlSplitKCShuffle
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// Builds the adaptor used as this operation's Block2CTileMap by chaining two
// single-stage adaptors.
//
// Parameters:
//   batch_count - length passed to the insert transform and used as the
//                 leading length of the merge transform.
//   M, N        - extents that are divided by MPerBlock / NPerBlock (names
//                 taken from the enclosing scope).
//   M01, N01    - used as divisors of the per-block quotients and as the
//                 inner lengths of the unmerge transforms.
//
// Returns the result of chain_tensor_adaptors(); judging by the local names
// in the original code it maps a global block id to a (batch, m0, n0)
// index -- NOTE(review): confirm against the gridwise kernel that uses it.
static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01, index_t N01)
{
    // Quotients of the extents by the per-block sizes.
    const auto m_tile_count = M / MPerBlock;
    const auto n_tile_count = N / NPerBlock;

    // Outer factors once the M01 / N01 inner factors are divided out.
    const auto m_outer = m_tile_count / M01;
    const auto n_outer = n_tile_count / N01;

    // Stage applied last in the chain: insert a batch_count-long dimension
    // and unmerge dimension 0 into (m_outer, M01) and dimension 1 into
    // (n_outer, N01).
    const auto split_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_insert_transform(batch_count),
                   make_unmerge_transform(make_tuple(m_outer, M01)),
                   make_unmerge_transform(make_tuple(n_outer, N01))),
        make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

    // Stage applied first in the chain: a single merge transform over the
    // lengths (batch_count, m_outer, n_outer, M01, N01).
    const auto merge_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(batch_count, m_outer, n_outer, M01, N01))),
        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
        make_tuple(Sequence<0>{}));

    return chain_tensor_adaptors(split_adaptor, merge_adaptor);
}
struct
ComputePtrOffsetOfStridedBatch
struct
ComputePtrOffsetOfStridedBatch
{
{
ComputePtrOffsetOfStridedBatch
(
const
index_t
BatchStrideA
,
const
index_t
BatchStrideB
)
ComputePtrOffsetOfStridedBatch
(
const
index_t
BatchStrideA
,
const
index_t
BatchStrideB
)
...
@@ -443,8 +475,7 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -443,8 +475,7 @@ struct DeviceGemmXdlSplitKCShuffle
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
1
,
1
,
1
,
1
,
1
));
decltype
(
BatchedGemmUtil
::
MakeBlock2CTileMap
<
MPerBlock
,
NPerBlock
>
(
1
,
1
,
1
));
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -540,8 +571,11 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -540,8 +571,11 @@ struct DeviceGemmXdlSplitKCShuffle
compute_ptr_offset_of_batch_
=
compute_ptr_offset_of_batch_
=
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
block_2_ctile_map_
=
BatchedGemmUtil
::
MakeBlock2CTileMap
<
MPerBlock
,
NPerBlock
>
(
block_2_ctile_map_
=
MakeBlock2CTileMap
(
BatchCount_
,
BatchCount_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
));
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
1
,
1
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment