Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6a9d7b64
Commit
6a9d7b64
authored
Feb 27, 2023
by
aska-0096
Browse files
temp save
parent
d4adc71a
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
312 additions
and
145 deletions
+312
-145
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+6
-6
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+4
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+40
-20
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+223
-90
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+3
-5
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+15
-21
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+20
-0
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
6a9d7b64
...
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault
,
GemmDefault
,
256
,
// BlockSize
256
,
// BlockSize
128
,
// MPerBlock
128
,
// MPerBlock
1
28
,
// NPerBlock
1
6
,
// NPerBlock
64
,
// KPerBlock
32
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
1
,
// M Repeat
1
,
// M Repeat
8
,
// N-Repeat
1
,
// N-Repeat
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
6
4
,
1
>
,
S
<
4
,
1
6
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
...
@@ -59,8 +59,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -59,8 +59,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8
,
8
,
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
4
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
64
,
1
,
4
>
,
S
<
1
,
128
,
1
,
2
>
,
8
>
;
8
>
;
// clang-format on
// clang-format on
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
6a9d7b64
...
@@ -94,12 +94,14 @@ using DeviceGemmInstance =
...
@@ -94,12 +94,14 @@ using DeviceGemmInstance =
TensorSpecB1
,
TensorSpecB1
,
TensorSpecC
,
TensorSpecC
,
256
,
256
,
// Gemm 0
128
,
// MPerBlock
128
,
// MPerBlock
128
,
// LPerBlock
128
,
// LPerBlock
4
,
// K
0
PerBlock
32
,
// KPerBlock
8
,
// K1
8
,
// K1
// Gemm 1
64
,
// NPerBlock
64
,
// NPerBlock
4
,
// L
0
PerBlock
32
,
// LPerBlock
8
,
// L1
8
,
// L1
16
,
// MPerWMMA
16
,
// MPerWMMA
16
,
// LPerWMMA
16
,
// LPerWMMA
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
6a9d7b64
...
@@ -53,10 +53,10 @@ template <index_t NumDimG,
...
@@ -53,10 +53,10 @@ template <index_t NumDimG,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
K
0
PerBlock
,
// K0 * K1 = Gemm0 GEMM_K Dim
ck
::
index_t
KPerBlock
,
ck
::
index_t
K1
,
//
ck
::
index_t
K1
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
L
0
PerBlock
,
ck
::
index_t
LPerBlock
,
ck
::
index_t
L1
,
ck
::
index_t
L1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
LPerWMMA
,
ck
::
index_t
LPerWMMA
,
...
@@ -128,8 +128,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -128,8 +128,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
index_t
NumDimGemm1N
=
NumDimN
;
static
constexpr
index_t
NumDimGemm1N
=
NumDimN
;
static
constexpr
index_t
NumDimGemm1K
=
NumDimL
;
static
constexpr
index_t
NumDimGemm1K
=
NumDimL
;
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
K1
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
;
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -137,6 +135,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -137,6 +135,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
LWaves
=
LPerBlock
/
(
LRepeat
*
LPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
16
;
static
constexpr
auto
AEnableLds
=
LWaves
==
1
?
false
:
true
;
// static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
// static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
Sequence
<
MPerBlock
,
LPerBlock
,
KPerBlock
,
NPerBlock
>
,
Sequence
<
MPerBlock
,
LPerBlock
,
KPerBlock
,
NPerBlock
>
,
...
@@ -146,12 +153,22 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -146,12 +153,22 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1Spec
,
B1Spec
,
CSpec
>
;
CSpec
>
;
static
auto
MakeAGridDescriptor
_AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor
(
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides_vec
)
{
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
if
constexpr
(
AEnableLds
)
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
{
Number
<
K1
>
{});
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
Number
<
K1
>
{});
}
else
{
return
Transform
::
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
WmmaK
,
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{},
Number
<
K1
>
{})
}
}
}
static
auto
MakeB0GridDescriptor_BK0_L_BK1
(
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths_vec
,
static
auto
MakeB0GridDescriptor_BK0_L_BK1
(
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths_vec
,
...
@@ -170,7 +187,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -170,7 +187,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number
<
L1
>
{});
Number
<
L1
>
{});
}
}
using
AGridDesc
_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor
_AK0_M_AK1
({},
{}));
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
({},
{}));
using
B0GridDesc_BK0_L_BK1
=
decltype
(
MakeB0GridDescriptor_BK0_L_BK1
({},
{}));
using
B0GridDesc_BK0_L_BK1
=
decltype
(
MakeB0GridDescriptor_BK0_L_BK1
({},
{}));
using
B1GridDesc_BL0_N_BL1
=
decltype
(
MakeB1GridDescriptor_BL0_N_BL1
({},
{}));
using
B1GridDesc_BL0_N_BL1
=
decltype
(
MakeB1GridDescriptor_BL0_N_BL1
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
...
@@ -250,17 +267,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -250,17 +267,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
// InMemory Data Descriptor
// InMemory Data Descriptor
AGridDesc
_AK0_M_AK1
,
AGridDesc
,
B0GridDesc_BK0_L_BK1
,
B0GridDesc_BK0_L_BK1
,
B1GridDesc_BL0_N_BL1
,
B1GridDesc_BL0_N_BL1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
// Tiling Family
// Tiling Family
MPerBlock
,
MPerBlock
,
LPerBlock
,
LPerBlock
,
K
0
PerBlock
,
// K0 * K1 = Gemm0 GEMM_K Dim
KPerBlock
,
K1
,
//
K1
,
NPerBlock
,
NPerBlock
,
L
0
PerBlock
,
LPerBlock
,
L1
,
L1
,
MPerWMMA
,
MPerWMMA
,
LPerWMMA
,
LPerWMMA
,
...
@@ -277,6 +294,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -277,6 +294,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
true
,
true
,
AEnableLds
,
ABlockLdsAddExtraM
,
ABlockLdsAddExtraM
,
B0BlockTransferThreadClusterLengths_K0_L_K1
,
B0BlockTransferThreadClusterLengths_K0_L_K1
,
B0BlockTransferThreadClusterArrangeOrder
,
B0BlockTransferThreadClusterArrangeOrder
,
...
@@ -285,6 +303,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -285,6 +303,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferSrcScalarPerVector
,
B0BlockTransferDstScalarPerVector_K1
,
B0BlockTransferDstScalarPerVector_K1
,
true
,
true
,
B0EnableLds
,
B0BlockLdsAddExtraL
,
B0BlockLdsAddExtraL
,
B1BlockTransferThreadClusterLengths_L0_N_L1
,
B1BlockTransferThreadClusterLengths_L0_N_L1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
...
@@ -293,6 +312,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -293,6 +312,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_L1
,
B1BlockTransferDstScalarPerVector_L1
,
false
,
false
,
B1EnableLds
,
B1BlockLdsAddExtraN
,
B1BlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
...
@@ -338,7 +358,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -338,7 +358,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor
_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b0_grid_desc_bk0_l_bk1_
{
DeviceOp
::
MakeB0GridDescriptor_BK0_L_BK1
(
b0_grid_desc_bk0_l_bk1_
{
DeviceOp
::
MakeB0GridDescriptor_BK0_L_BK1
(
b0_gs_ls_ks_lengths
,
b0_gs_ls_ks_strides
)},
b0_gs_ls_ks_lengths
,
b0_gs_ls_ks_strides
)},
b1_grid_desc_bl0_n_bl1_
{
DeviceOp
::
MakeB1GridDescriptor_BL0_N_BL1
(
b1_grid_desc_bl0_n_bl1_
{
DeviceOp
::
MakeB1GridDescriptor_BL0_N_BL1
(
...
@@ -404,7 +424,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -404,7 +424,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
// Tensor Descriptors
// Tensor Descriptors
AGridDesc
_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc
a_grid_desc_ak0_m_ak1_
;
B0GridDesc_BK0_L_BK1
b0_grid_desc_bk0_l_bk1_
;
B0GridDesc_BK0_L_BK1
b0_grid_desc_bk0_l_bk1_
;
B1GridDesc_BL0_N_BL1
b1_grid_desc_bl0_n_bl1_
;
B1GridDesc_BL0_N_BL1
b1_grid_desc_bl0_n_bl1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
...
@@ -463,7 +483,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -463,7 +483,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
DeviceOp
::
AGridDesc
_AK0_M_AK1
,
DeviceOp
::
AGridDesc
,
DeviceOp
::
B0GridDesc_BK0_L_BK1
,
DeviceOp
::
B0GridDesc_BK0_L_BK1
,
DeviceOp
::
B1GridDesc_BL0_N_BL1
,
DeviceOp
::
B1GridDesc_BL0_N_BL1
,
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -741,11 +761,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -741,11 +761,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
K
0
PerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
K1
<<
", "
<<
K1
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
L
0
PerBlock
<<
", "
<<
LPerBlock
<<
", "
<<
L1
<<
L1
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
<<
"ASpec"
<<
getTensorSpecializationString
(
ASpec
)
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
6a9d7b64
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
6a9d7b64
...
@@ -343,7 +343,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
...
@@ -343,7 +343,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_block_buf
=
a_block_buf_switch
;
//
a_block_buf = a_block_buf_switch;
++
i
;
++
i
;
}
while
(
i
<
(
num_loop
-
1
));
}
while
(
i
<
(
num_loop
-
1
));
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
6a9d7b64
...
@@ -130,8 +130,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -130,8 +130,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
B_K0
=
BGridDesc_K0_N_K1
{}.
GetLength
(
I0
);
static
constexpr
auto
B_K1
=
BGridDesc_K0_N_K1
{}.
GetLength
(
I2
);
// FIX ME: To be deprecated
// FIX ME: To be deprecated
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
...
@@ -273,6 +271,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -273,6 +271,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1
(
const
BBlockDesc_BK0_N_BK1
&
)
MakeBBlockDescriptor_K0_N0_N1_N2_K1
(
const
BBlockDesc_BK0_N_BK1
&
)
{
{
constexpr
auto
B_K0
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
BBlockDesc_BK0_N_BK1
{},
BBlockDesc_BK0_N_BK1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
...
@@ -528,8 +529,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -528,8 +529,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
}
}();
}();
// printf("---------------K = %d\n", K);
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
b_block_desc
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
b_block_desc
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
...
@@ -703,7 +702,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -703,7 +702,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
/*******************************************************************************/
// Shift Per SUB_K
// Shift Per SUB_K
constexpr
auto
a_block_slice_copy_step
=
MakeABlockSliceCopyStep
();
constexpr
auto
a_block_slice_copy_step
=
MakeABlockSliceCopyStep
();
// printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
constexpr
auto
b_block_slice_copy_step
=
MakeBBlockSliceCopyStep
();
constexpr
auto
b_block_slice_copy_step
=
MakeBBlockSliceCopyStep
();
// gridwise GEMM pipeline
// gridwise GEMM pipeline
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
6a9d7b64
...
@@ -1395,34 +1395,28 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
...
@@ -1395,34 +1395,28 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation
// apply element-wise operation
element_op_
(
v_this_row
,
src_buf
[
Number
<
src_offset
>
{}]);
element_op_
(
v_this_row
,
src_buf
[
Number
<
src_offset
>
{}]);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
if
constexpr
(
IntraRowSwizzlePerm
)
if
constexpr
(
IntraRowSwizzlePerm
)
{
{
temp
=
__builtin_amdgcn_permlane16
(
// 0x76543210, 0xfedcba98
//
temp = __builtin_amdgcn_permlane16(
temp
,
//
temp,
type_convert
<
int
>
(
v_this_row
),
//
type_convert<int>(v_this_row),
0xb3a29180
,
//
0xb3a29180,
0xf7e6d5c4
,
//
0xf7e6d5c4,
1
,
//
1,
0
);
//
0);
v_this_row
=
type_convert
<
SrcData
>
(
temp
);
v_this_row
=
type_convert
<
SrcData
>
(
temp
);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) );
}
}
// apply inter-row permute.
// apply inter-row permute.
temp
=
__builtin_amdgcn_permlanex16
(
temp
,
//
temp = __builtin_amdgcn_permlanex16(temp,
type_convert
<
int
>
(
v_this_row
),
//
type_convert<int>(v_this_row),
LowEightRowlaneIdx
,
//
LowEightRowlaneIdx,
HighEightRowLaneIdx
,
//
HighEightRowLaneIdx,
1
,
//
1,
0
);
//
0);
v_theother_row
=
type_convert
<
SrcData
>
(
temp
);
v_theother_row
=
type_convert
<
SrcData
>
(
temp
);
// printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_theother_row)) );
if
(
get_thread_local_1d_id
()
%
32
<
16
)
if
(
get_thread_local_1d_id
()
%
32
<
16
)
{
{
// apply type convert
// apply type convert
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
6a9d7b64
...
@@ -179,6 +179,26 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -179,6 +179,26 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
template
<
typename
AGridDesc_M_K
,
typename
Number
>
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
Number
&
WmmaK
,
const
Number
&
MRepeat
,
const
Number
&
MWaves
,
const
Number
&
MPerWmma
,
const
Number
&
AK1
)
{
const
auto
M0
=
a_grid_desc_m_k
.
GetLength
(
I0
)
/
MPerBlcok
;
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AKWmma
=
K
/
WmmaK
;
constexpr
auto
AKRow
=
WmmaK
/
K1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AKWmma
,
Number
<
AKRow
>
{},
AK1
)),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
MWaves
,
MPerWmma
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
}
//
//
// B (alias of B0)
// B (alias of B0)
//
//
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment