gaoqiong / composable_kernel_ROCM / Commits

Commit 4c850c90 (unverified), authored Jun 19, 2024 by Rostyslav Geyyer, committed by GitHub on Jun 19, 2024

Merge branch 'develop' into lwpck-1815

Parents: ce30621d, 1973903f
Changes: 52
Showing 20 changed files with 562 additions and 301 deletions (+562 −301)
client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp  (+1 −1)
client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp  (+1 −1)
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp  (+1 −1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp  (+4 −4)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp  (+4 −4)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp  (+2 −2)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp  (+2 −2)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp  (+17 −25)
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp  (+19 −24)
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp  (+38 −10)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  (+6 −3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp  (+6 −3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp  (+6 −3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp  (+12 −6)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp  (+6 −3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp  (+11 −6)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp  (+22 −12)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp  (+3 −3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp  (+356 −188)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  (+45 −0)
client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp

@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp

@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp

@@ -63,7 +63,7 @@ using DeviceGemmInstance =
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4, 4, 4>>;
 // clang-format on

 struct ProblemSize final
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp

@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;

@@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
     static constexpr index_t PrefetchStages  = 1;
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp

@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = PrefetchStages;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % PrefetchStages == 1)
         {

@@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = PrefetchStages;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % PrefetchStages == 1)
         {
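These pipeline hunks only add a `__device__` qualifier to helpers that already existed; the selection logic itself is unchanged. As a rough plain-C++ illustration of that logic (the branch bodies are truncated in this view, so the One/Full mapping below is an assumption, and PipelineV2Sketch with PrefetchStages = 2 is an illustrative stand-in rather than CK's type):

// Minimal host-side sketch, not the CK implementation.
#include <cassert>

enum class TailNumber { One, Full };

struct PipelineV2Sketch
{
    static constexpr int PrefetchStages = 2; // assumption: a 2-stage prefetch pipeline

    // In the diff these become __host__ __device__ so the kernel can call them too;
    // in plain C++ they are simply constexpr.
    static constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }

    static constexpr TailNumber BlockLoopTailNum(int num_loop)
    {
        // v2 rule from the hunk above: a remainder of 1 leaves a short tail (assumed One).
        return (num_loop % PrefetchStages == 1) ? TailNumber::One : TailNumber::Full;
    }
};

int main()
{
    static_assert(PipelineV2Sketch::BlockHasHotloop(5), "5 loops exceed 2 prefetch stages");
    static_assert(PipelineV2Sketch::BlockLoopTailNum(5) == TailNumber::One, "5 % 2 == 1");
    static_assert(PipelineV2Sketch::BlockLoopTailNum(4) == TailNumber::Full, "4 % 2 == 0");
    return 0;
}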
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp

@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t GlobalBufferNum = 2;
     static constexpr index_t HotloopUnroll   = 2;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % HotloopUnroll == 1)
         {
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             // for sanity check of vector memory access
             for(index_t i = 0; i < NumATensor; ++i)
             {
-                as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
-                as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
-                as_max_read_elems_[i] =
+                tie(as_continous_dim_[i], as_max_read_elems_[i]) =
                     CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
             }

             for(index_t i = 0; i < NumBTensor; ++i)
             {
-                bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
-                bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
-                bs_max_read_elems_[i] =
+                tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
                     CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
             }

             for(index_t i = 0; i < NumDTensor; ++i)
             {
-                ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
-                ds_max_read_elems_[i] =
+                tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
                     CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
             }

-            e_nz_consecutive_  = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
-            e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
+            tie(e_continous_dim_, e_max_write_elems_) =
+                CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
         }

         // pointers

@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;

-        // Describe whether the last part of a given dimension of A/B/D/E is consecutive
-        // in the memory or not.
-        std::array<bool, NumATensor> as_mz_consecutive_;
-        std::array<bool, NumATensor> as_kz_consecutive_;
-        std::array<bool, NumBTensor> bs_nz_consecutive_;
-        std::array<bool, NumBTensor> bs_kz_consecutive_;
-        std::array<bool, NumDTensor> ds_nz_consecutive_;
-        bool e_nz_consecutive_;
+        // Describe whether the last part of a given dimension of A/B/D/E is continues dim.
+        std::array<index_t, NumATensor> as_continous_dim_;
+        std::array<index_t, NumATensor> bs_continous_dim_;
+        std::array<index_t, NumBTensor> ds_continous_dim_;
+        index_t e_continous_dim_;

         std::array<index_t, NumATensor> as_max_read_elems_;
         std::array<index_t, NumBTensor> bs_max_read_elems_;

@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
                 const bool valid_a_vector_size =
                     arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
                 const bool valid_a_access_dim_m =
-                    ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
+                    ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
                 const bool valid_a_access_dim_k =
-                    ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
+                    ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
                 const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
                 if(!((valid_a_vector_size && valid_a_access_dim) ||
                      ABlockTransferSrcScalarPerVector == 1))

@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
                 const bool valid_b_vector_size =
                     arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
                 const bool valid_b_access_dim_n =
-                    BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
+                    BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
                 const bool valid_b_access_dim_k =
-                    BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
+                    BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
                 const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
                 if(!((valid_b_vector_size && valid_b_access_dim) ||
                      BBlockTransferSrcScalarPerVector == 1))

@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
                 const bool valid_d_vector_size =
                     arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
                 // Vector read of Ds is always on N dimension.
-                const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+                const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
                 if(!((valid_d_vector_size && valid_d_access_dim) ||
                      CDEBlockTransferScalarPerVector_NPerBlock == 1))
                 {

@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             const bool valid_e_vector_size =
                 arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
             // Vector write of E is always on N dimension.
-            const bool valid_e_access_dim = arg.e_nz_consecutive_;
+            const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
            if(!((valid_e_vector_size && valid_e_access_dim) ||
                 CDEBlockTransferScalarPerVector_NPerBlock == 1))
            {
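The validity checks in these hunks switch from per-dimension "consecutive" booleans to a single continuous-dim index returned by CalculateMaxRead. A hedged, self-contained sketch of that rule follows; the names (SrcVectorDim, ScalarPerVector, continuous_dim, max_read_elems) are illustrative stand-ins for the template parameters and Argument members used above, not CK's API:

#include <iostream>

// continuous_dim: 0 means the first dimension group (M or N) ends in the unit-stride
// dimension, 1 means the second group (K or N) does. max_read_elems is the longest
// contiguous run reported by the CalculateMaxRead-style helper.
bool IsVectorAccessValid(int SrcVectorDim, int ScalarPerVector, int continuous_dim, int max_read_elems)
{
    const bool valid_vector_size = max_read_elems % ScalarPerVector == 0;
    const bool valid_dim_first   = SrcVectorDim == 1 && continuous_dim == 0;
    const bool valid_dim_second  = SrcVectorDim == 2 && continuous_dim == 1;
    const bool valid_access_dim  = valid_dim_first || valid_dim_second;
    // Width-1 (scalar) transfers are always allowed, mirroring the "|| ... == 1" fallback.
    return (valid_vector_size && valid_access_dim) || ScalarPerVector == 1;
}

int main()
{
    std::cout << IsVectorAccessValid(2, 8, 1, 64) << '\n'; // vector dim is contiguous, 64 % 8 == 0 -> 1
    std::cout << IsVectorAccessValid(2, 8, 0, 64) << '\n'; // vector dim is not the contiguous one -> 0
    std::cout << IsVectorAccessValid(1, 1, 1, 5)  << '\n'; // scalar access always passes -> 1
}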
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp

@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             }

             // for sanity check of vector memory access
-            a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
-            a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
-            a_max_read_elems_ =
+            tie(a_continous_dim_, a_max_read_elems_) =
                 CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);

-            b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
-            b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
-            b_max_read_elems_ =
+            tie(b_continous_dim_, b_max_read_elems_) =
                 CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);

             for(index_t i = 0; i < NumDTensor; ++i)
             {
-                ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
-                ds_max_read_elems_[i] =
+                tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
                     CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
             }

-            e_nz_consecutive_  = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
-            e_max_write_elems_ =
+            tie(e_continous_dim_, e_max_write_elems_) =
                 CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
         }

@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;

-        // Describe whether the last part of a given dimension of A/B/D/E is consecutive
-        // in the memory or not.
-        bool a_mz_consecutive_;
-        bool a_kz_consecutive_;
-        bool b_nz_consecutive_;
-        bool b_kz_consecutive_;
-        std::array<bool, NumDTensor> ds_nz_consecutive_;
-        bool e_nz_consecutive_;
+        // Describe whether the last part of a given dimension of A/B/D/E is continues dim.
+        index_t a_continous_dim_;
+        index_t b_continous_dim_;
+        std::array<index_t, NumDTensor> ds_continous_dim_;
+        index_t e_continous_dim_;

         index_t a_max_read_elems_;
         index_t b_max_read_elems_;

@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             const bool valid_a_vector_size =
                 arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
-            const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
-            const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
+            const bool valid_a_access_dim_m =
+                ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
+            const bool valid_a_access_dim_k =
+                ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
             const bool valid_a_access_dim =
                 valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
             if(!(valid_a_vector_size && valid_a_access_dim))

@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             const bool valid_b_vector_size =
                 arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
-            const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
-            const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
+            const bool valid_b_access_dim_n =
+                BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
+            const bool valid_b_access_dim_k =
+                BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
             const bool valid_b_access_dim =
                 valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
             if(!(valid_b_vector_size && valid_b_access_dim))

@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                     arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
                 // Vector read of Ds is always on N dimension.
                 const bool valid_d_access_dim =
-                    arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
+                    arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
                 if(!(valid_d_vector_size && valid_d_access_dim))
                 {
                     valid_ds_access = false;

@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                 arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
             // Vector write of E is always on N dimension.
             const bool valid_e_access_dim =
-                arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
+                arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
             if(!(valid_e_vector_size && valid_e_access_dim))
             {
                 return false;
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
     }

     // Determine the beginning and end idx of the group representing the FCD.
-    index_t begin_idx, end_idx;
-    if(strides[NumDim1 - 1] == 1)
+    index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
+    if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
     {
-        begin_idx = 0;
-        end_idx   = NumDim1 - 1;
+        // MZ or KZ are ones
+        bool dims1_are_ones = true;
+        for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
+        {
+            if(lengths[dim_idx] != 1)
+            {
+                dims1_are_ones = false;
+            }
+        }
+        if(dims1_are_ones)
+        {
+            begin_idx     = NumDim1;
+            end_idx       = NumDim1 + NumDim2 - 1;
+            continous_dim = 1;
+        }
+        else
+        {
+            begin_idx     = 0;
+            end_idx       = NumDim1 - 1;
+            continous_dim = 0;
+        }
+    }
+    else if(strides[NumDim1 - 1] == 1)
+    {
+        begin_idx     = 0;
+        end_idx       = NumDim1 - 1;
+        continous_dim = 0;
     }
     else if(strides[NumDim1 + NumDim2 - 1] == 1)
     {
-        begin_idx = NumDim1;
-        end_idx   = NumDim1 + NumDim2 - 1;
+        begin_idx     = NumDim1;
+        end_idx       = NumDim1 + NumDim2 - 1;
+        continous_dim = 1;
     }
     else
     {
         // The dimension consecutive in memory is not the last dimension of any group, so only
         // one element can be read/written at once.
-        return 1;
+        consecutive_stride = 1;
+        continous_dim      = 0;
+        return make_tuple(continous_dim, consecutive_stride);
     }

-    index_t consecutive_stride = 1;
     for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
     {
         if(strides[dim_idx] == consecutive_stride)

@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
         }
     }

     const index_t max_subsequent_elems = consecutive_stride;
-    return max_subsequent_elems;
+    return make_tuple(continous_dim, max_subsequent_elems);
 }

 } // namespace device
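After this change CalculateMaxRead reports two things: which dimension group ends in the fastest-changing dimension and how many elements can be accessed contiguously there. The following standalone sketch mirrors that behavior (assumptions: plain std:: types, no CK headers, and it omits the new all-ones special case the commit adds for degenerate dimension groups):

#include <iostream>
#include <utility>
#include <vector>

// Returns {continuous_dim, max_contiguous_elems}: 0 = first group ends in the unit-stride
// dimension, 1 = second group does.
std::pair<int, long> MaxContiguousRead(int NumDim1,
                                       int NumDim2,
                                       const std::vector<long>& lengths,
                                       const std::vector<long>& strides)
{
    int begin_idx, end_idx, continuous_dim;
    if(strides[NumDim1 - 1] == 1)
    {
        begin_idx      = 0;
        end_idx        = NumDim1 - 1;
        continuous_dim = 0;
    }
    else if(strides[NumDim1 + NumDim2 - 1] == 1)
    {
        begin_idx      = NumDim1;
        end_idx        = NumDim1 + NumDim2 - 1;
        continuous_dim = 1;
    }
    else
    {
        return {0, 1}; // no group ends in the unit-stride dimension: scalar access only
    }

    // Walk the group from its innermost dimension outwards while the strides stay packed.
    long contiguous = 1;
    for(int d = end_idx; d >= begin_idx; --d)
    {
        if(strides[d] != contiguous)
            break;
        contiguous *= lengths[d];
    }
    return {continuous_dim, contiguous};
}

int main()
{
    // A 2+2-dim tensor with lengths (M0, M1, K0, K1) = (4, 8, 2, 16), packed in (K0, K1).
    std::vector<long> lengths{4, 8, 2, 16};
    std::vector<long> strides{256, 32, 16, 1};
    auto result = MaxContiguousRead(2, 2, lengths, strides);
    std::cout << "continuous group: " << result.first
              << ", contiguous elems: " << result.second << '\n'; // 1, 32
}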
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -93,9 +93,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp

@@ -54,9 +54,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t c_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp

@@ -66,9 +66,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t c_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -59,9 +59,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -113,9 +116,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

@@ -97,9 +97,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
-    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t c_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -106,10 +106,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);

-    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
-    const auto& ds_batch_offset       = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+    const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);

-    const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+    const long_index_t e_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -170,10 +172,13 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-        const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+        const long_index_t a_batch_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+        const long_index_t b_batch_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));

-        const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+        const long_index_t a_n_offset =
+            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));

         GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_batch_offset + a_n_offset,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

@@ -85,12 +85,17 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
-    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
-    const long_index_t a_n_offset     = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
-    const long_index_t e_n_offset     = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+    const long_index_t a_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+    const long_index_t e_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -142,12 +147,17 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);

-    const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
-    const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
-    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
-    const long_index_t a_n_offset     = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
-    const long_index_t e_n_offset     = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+    const long_index_t a_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+    const long_index_t e_n_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -161,11 +161,11 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
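The point of replacing the raw __builtin_amdgcn_readfirstlane(...) with CK's amd_wave_read_first_lane(...) in these kernels is that the raw builtin operates on 32-bit values, while the batch offsets here are long_index_t (64-bit). As a heavily hedged sketch of the general idea only (the readfirstlane stub below stands in for the hardware intrinsic, and this is not CK's implementation), one can broadcast a 64-bit wave-uniform value by broadcasting its two 32-bit halves:

#include <cstdint>
#include <iostream>

// Stand-in for the 32-bit lane broadcast; on an AMD GPU this role is played by the
// __builtin_amdgcn_readfirstlane intrinsic, which only handles 32-bit operands.
static uint32_t readfirstlane_u32_stub(uint32_t v) { return v; }

// Sketch: broadcast a 64-bit wave-uniform value via its two 32-bit halves.
static int64_t wave_read_first_lane_i64_sketch(int64_t v)
{
    const uint64_t u  = static_cast<uint64_t>(v);
    const uint32_t lo = readfirstlane_u32_stub(static_cast<uint32_t>(u));
    const uint32_t hi = readfirstlane_u32_stub(static_cast<uint32_t>(u >> 32));
    return static_cast<int64_t>((static_cast<uint64_t>(hi) << 32) | lo);
}

int main()
{
    const int64_t batch_offset = (int64_t{1} << 40) + 123; // deliberately wider than 32 bits
    std::cout << wave_read_first_lane_i64_sketch(batch_offset) << '\n';
}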
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp

@@ -19,6 +19,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"

@@ -42,16 +43,22 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
+          typename ADataType,
+          typename BDataType,
           typename DsDataType,
+          typename EDataType,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename ELayout,
+          index_t KPerBlock,
           typename OffsettedBlockToCTileMap,
           typename LocalBlock2ETileMap,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation>
+          typename CDEElementwiseOperation,
+          BlockGemmPipelineScheduler BlkGemmPipeSched,
+          BlockGemmPipelineVersion BlkGemmPipelineVer>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -67,6 +74,7 @@ __global__ void
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
+    __shared__ uint8_t p_shared1[shared_size];

     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

@@ -81,27 +89,8 @@ __global__ void
     index_t gemm_tile_id_start = 0;
     index_t gemm_tile_id_end   = 0;

-    using AGridDescMK = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(1, 1, 1))>;
-    using BGridDescNK = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(1, 1, 1))>;
-    using EGridDescMN = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(1, 1, 1))>;
-    using DsGridDescMN = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
-
     index_t M = 0, N = 0, K = 0;
-    index_t StrideA, StrideB, StrideE;
-    std::array<index_t, NumDTensor> StrideDs;
-
-    AGridDescMK a_grid_desc_mk;
-    BGridDescNK b_grid_desc_nk;
-    EGridDescMN e_grid_desc_mn;
-    DsGridDescMN ds_grid_desc_mn;

     auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);

     do

@@ -127,31 +116,13 @@ __global__ void
             }

             b2c_tile_map =
-                OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
+                OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
             grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);

             gemm_tile_id_start = group_offset;
             gemm_tile_id_end   = group_offset + grid_size_grp;
         }

-        StrideA  = gemm_desc_ptr[group_id].StrideA;
-        StrideB  = gemm_desc_ptr[group_id].StrideB;
-        StrideDs = gemm_desc_ptr[group_id].StrideDs;
-        StrideE  = gemm_desc_ptr[group_id].StrideE;
-
-        a_grid_desc_mk =
-            GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-        b_grid_desc_nk =
-            GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-        e_grid_desc_mn =
-            GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-
-        static_for<0, NumDTensor, 1>{}([&](auto j) {
-            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-            ds_grid_desc_mn(j) =
-                GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
-        });
-
         using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
         DsGridPointer p_ds_grid;

@@ -160,42 +131,268 @@ __global__ void
             p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });

-        bool has_main_kblock_loop =
-            GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
+        static constexpr index_t kbatch  = 1;
+        static constexpr index_t k_grain = kbatch * KPerBlock;
+        index_t K_split                  = (K + k_grain - 1) / k_grain * KPerBlock;
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

         // Update tile offset if we have moved within group
         b2c_tile_map.UpdateTileOffset(tile_offset);

-        if(has_main_kblock_loop)
+        using Problem = typename GridwiseGemm::Problem;
+        auto problem  = Problem(gemm_desc_ptr[group_id].M,
+                                gemm_desc_ptr[group_id].N,
+                                gemm_desc_ptr[group_id].K,
+                                gemm_desc_ptr[group_id].StrideA,
+                                gemm_desc_ptr[group_id].StrideB,
+                                gemm_desc_ptr[group_id].StrideDs,
+                                gemm_desc_ptr[group_id].StrideE,
+                                kbatch);
+
+        if(has_main_k_block_loop)
         {
-            GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
-                                             gemm_desc_ptr[group_id].p_b_grid,
-                                             p_ds_grid,
-                                             gemm_desc_ptr[group_id].p_e_grid,
-                                             static_cast<void*>(p_shared),
-                                             a_element_op,
-                                             b_element_op,
-                                             cde_element_op,
-                                             a_grid_desc_mk,
-                                             b_grid_desc_nk,
-                                             ds_grid_desc_mn,
-                                             e_grid_desc_mn,
-                                             b2c_tile_map);
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                           InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                    static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                    static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                    p_ds_grid,
+                    static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                    static_cast<void*>(p_shared),
+                    problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+            }
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                {
+                    GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                               InMemoryDataOperationEnum::Set, TailNumber::One>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
+                {
+                    GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                               InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Two>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Three>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Four>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Five>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Six>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true,
+                                                   InMemoryDataOperationEnum::Set, TailNumber::Seven>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+            }
+            // Tail number could be Odd or Even
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                {
+                    GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap, true,
+                                                    InMemoryDataOperationEnum::Set, TailNumber::Odd>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        static_cast<void*>(p_shared1),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                else
+                {
+                    GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap, true,
+                                                    InMemoryDataOperationEnum::Set, TailNumber::Even>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        static_cast<void*>(p_shared1),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+            }
         }
         else
         {
-            GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
-                                              gemm_desc_ptr[group_id].p_b_grid,
-                                              p_ds_grid,
-                                              gemm_desc_ptr[group_id].p_e_grid,
-                                              static_cast<void*>(p_shared),
-                                              a_element_op,
-                                              b_element_op,
-                                              cde_element_op,
-                                              a_grid_desc_mk,
-                                              b_grid_desc_nk,
-                                              ds_grid_desc_mn,
-                                              e_grid_desc_mn,
-                                              b2c_tile_map);
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+            {
+                GridwiseGemm::template Run<OffsettedBlockToCTileMap, false,
+                                           InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                    static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                    static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                    p_ds_grid,
+                    static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                    static_cast<void*>(p_shared),
+                    problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+            }
         }

         tile_id += get_grid_size();

@@ -253,10 +450,12 @@ template <typename ALayout,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-          PipelineVersion PipelineVer = PipelineVersion::v1,
-          LoopScheduler LoopSched     = make_default_loop_scheduler(),
-          typename ComputeDataType    = EDataType>
+          typename CDEShuffleBlockTransferScalarPerVectors,
+          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
+          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+          typename ComputeTypeA                       = EDataType,
+          typename ComputeTypeB                       = ComputeTypeA>
 struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop : public DeviceGroupedGemmTileLoop<ALayout,
                                                                                         BLayout,

@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
     using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;

     static constexpr index_t NumDTensor = DsDataType::Size();

-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+    using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
+        ALayout,
+        BLayout,
+        DsLayout,
+        ELayout,
         ADataType,
         BDataType,
-        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,

@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         AElementwiseOperation,
         BElementwiseOperation,
         CDEElementwiseOperation,
-        InMemoryDataOperationEnum::Set,
-        NumGemmKPrefetchStage,
+        GemmSpec,
         BlockSize,
         MPerBlock,
         NPerBlock,

@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-        LoopSched,
-        PipelineVer>;
-
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMap
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                                     index_t group_offset,
-                                                     index_t tile_offset)
-            : block_to_ctile_map_{block_to_ctile_map},
-              group_offset_{group_offset},
-              tile_offset_{tile_offset}
-        {
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            return block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(M, N);
-        }
-
-        __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t group_offset_;
-        index_t tile_offset_;
-    };
-
-    using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
-    using Block2ETileMap  = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
-    using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
+        CDEShuffleBlockTransferScalarPerVectors,
+        BlkGemmPipeSched,
+        BlkGemmPipelineVer,
+        ComputeTypeA,
+        ComputeTypeB>;
+
+    using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
+    using Block2ETileMap  = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+    using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;

     // Argument
     struct Argument : public BaseArgument

@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         const void* p_dev_gemm_args_;
        int occupancy_num_blocks_;
        int gpu_cu_count_;

        const std::vector<GemmDesc>& gemm_descs_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;

@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                               KernelArguments,
                                                               GemmSpec,
+                                                              ADataType,
+                                                              BDataType,
                                                               DsDataType,
+                                                              EDataType,
                                                               ALayout,
                                                               BLayout,
                                                               DsLayout,
                                                               ELayout,
-                                                              OffsetedLocalBlock2ETileMap,
+                                                              KPerBlock,
+                                                              OffsettedLocalBlock2ETileMap,
                                                               Block2ETileMap,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation>;
+                                                              CDEElementwiseOperation,
+                                                              BlkGemmPipeSched,
+                                                              BlkGemmPipelineVer>;

        return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
    }

@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
                          << std::endl;
            }

+            // run multiple kernels
            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),

@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
            return false;
        }

-        using DsGridDescMN = remove_cvref_t<decltype(
-            GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
-
        bool supported = true;
-        for(const auto& gdesc : arg.gemm_descs_)
+        constexpr index_t k_batch = 1;
+        for(index_t i = 0; i < arg.group_count_; ++i)
        {
-            const auto M = gdesc.M_;
-            const auto N = gdesc.N_;
-            const auto K = gdesc.K_;
-
-            const auto StrideA   = gdesc.stride_A_;
-            const auto StrideB   = gdesc.stride_B_;
-            const auto StrideE   = gdesc.stride_C_;
-            const auto& StrideDs = gdesc.stride_Ds_;
-
-            // If M dimension is unknown at launch time then validate just NK.
-            // If N or K dim is zero (or unknown) then the vector loads responsibility lies on
-            // the user.
-            if(N * K == 0)
-                continue;
-
-            const auto a_grid_desc_mk =
-                GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-            const auto b_grid_desc_nk =
-                GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-            const auto e_grid_desc_mn =
-                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-
-            DsGridDescMN ds_grid_desc_mn;
-            static_for<0, NumDTensor, 1>{}([&](auto j) {
-                using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-                ds_grid_desc_mn(j) =
-                    GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
-            });
-
-            const auto b2c_tile_map = Block2ETileMap(M, N);
-
-            if(!(GridwiseGemm::template CheckValidity(
-                     a_grid_desc_mk, b_grid_desc_nk, ds_grid_desc_mn, e_grid_desc_mn, b2c_tile_map) &&
-                 GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
-                     M, N, K)))
+            std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
+            std::array<index_t, NumDTensor> stride_Ds;
+            std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
+
+            using GridArg = typename GridwiseGemm::Argument;
+            GridArg gridwise_arg(nullptr,               // p_a_grid,
+                                 nullptr,               // p_b_grid,
+                                 placeholder_p_ds_grid, // p_ds_grid,
+                                 nullptr,               // p_e_grid ,
+                                 arg.gemm_descs_[i].M_,
+                                 arg.gemm_descs_[i].N_,
+                                 arg.gemm_descs_[i].K_,
+                                 arg.gemm_descs_[i].stride_A_,
+                                 arg.gemm_descs_[i].stride_B_,
+                                 stride_Ds,
+                                 arg.gemm_descs_[i].stride_C_,
+                                 k_batch,
+                                 arg.a_element_op_,
+                                 arg.b_element_op_,
+                                 arg.cde_element_op_);
+
+            if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
+               !(GemmSpec == GemmSpecialization::MKPadding ||
+                 GemmSpec == GemmSpecialization::NKPadding ||
+                 GemmSpec == GemmSpecialization::MNKPadding ||
+                 GemmSpec == GemmSpecialization::KPadding))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
                              << K << "] are not supported by current template parameters!"
                              << " In " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__;
                }
-                supported = false;
+                return false;
            }
+
+            supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
        }

        return supported;

@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                               KernelArguments,
                                                               GemmSpec,
+                                                              ADataType,
+                                                              BDataType,
                                                               DsDataType,
+                                                              EDataType,
                                                               ALayout,
                                                               BLayout,
                                                               DsLayout,
                                                               ELayout,
-                                                              OffsetedLocalBlock2ETileMap,
+                                                              KPerBlock,
+                                                              OffsettedLocalBlock2ETileMap,
                                                               Block2ETileMap,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation>;
+                                                              CDEElementwiseOperation,
+                                                              BlkGemmPipeSched,
+                                                              BlkGemmPipelineVer>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                               KernelArguments,
                                                               GemmSpec,
+                                                              ADataType,
+                                                              BDataType,
                                                               DsDataType,
+                                                              EDataType,
                                                               ALayout,
                                                               BLayout,
                                                               DsLayout,
                                                               ELayout,
-                                                              OffsetedLocalBlock2ETileMap,
+                                                              KPerBlock,
+                                                              OffsettedLocalBlock2ETileMap,
                                                               Block2ETileMap,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation>;
+                                                              CDEElementwiseOperation,
+                                                              BlkGemmPipeSched,
+                                                              BlkGemmPipelineVer>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    {
        auto str = std::ostringstream();
+        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
+            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
+            {BlockGemmPipelineVersion::v1, "v1"},
+            {BlockGemmPipelineVersion::v2, "v2"},
+            {BlockGemmPipelineVersion::v3, "v3"},
+            {BlockGemmPipelineVersion::v4, "v4"},
+            {BlockGemmPipelineVersion::v5, "v5"}};
+
        // clang-format off
        str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
            << "<"

@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
-            << PipelineVer << ", "
-            << LoopSched
+            << "BlkGemmPipelineScheduler: "
+            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+            << "BlkGemmPipelineVersion: "
+            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
            << ">";
        // clang-format on
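The reworked kernel body above maps a run-time tail count onto compile-time template instantiations of Run, guarded by "if constexpr" on the pipeline's PrefetchStages. The following is a hedged, self-contained sketch of that dispatch pattern only; GridwiseSketch, its PrefetchStages value, and the modulo rule in CalculateKBlockLoopTailNum are illustrative assumptions, not CK's actual definitions:

#include <iostream>

enum class TailNumber { One, Two, Three, Full };

struct GridwiseSketch
{
    static constexpr int PrefetchStages = 3; // assumption for illustration

    static TailNumber CalculateKBlockLoopTailNum(int k_split)
    {
        const int r = k_split % PrefetchStages;
        return r == 1 ? TailNumber::One : r == 2 ? TailNumber::Two : TailNumber::Full;
    }

    template <bool HasMainLoop, TailNumber Tail>
    static void Run() // stands in for the real Run(p_a, p_b, ..., b2c_tile_map)
    {
        std::cout << "Run<HasMainLoop=" << HasMainLoop
                  << ", Tail=" << static_cast<int>(Tail) << ">\n";
    }
};

template <typename Gemm>
void dispatch(int k_split)
{
    // Runtime value ...
    const TailNumber tail = Gemm::CalculateKBlockLoopTailNum(k_split);
    // ... enumerated into compile-time instantiations, as the kernel does.
    if(tail == TailNumber::One) { Gemm::template Run<true, TailNumber::One>(); }
    else if(tail == TailNumber::Full) { Gemm::template Run<true, TailNumber::Full>(); }
    if constexpr(Gemm::PrefetchStages > 2)
    {
        if(tail == TailNumber::Two) { Gemm::template Run<true, TailNumber::Two>(); }
    }
}

int main()
{
    dispatch<GridwiseSketch>(7); // 7 % 3 == 1 -> TailNumber::One
    dispatch<GridwiseSketch>(8); // 8 % 3 == 2 -> TailNumber::Two
    dispatch<GridwiseSketch>(9); // 9 % 3 == 0 -> TailNumber::Full
}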
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
     UnderlyingBlockToCTileMap block_to_ctile_map_;
     index_t block_start_;
 };

+// second version with 2 offsets
+template <typename UnderlyingBlockToCTileMap>
+struct OffsettedBlockToCTileMap2
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                  index_t group_offset,
+                                                  index_t tile_offset)
+        : block_to_ctile_map_{block_to_ctile_map},
+          group_offset_{group_offset},
+          tile_offset_{tile_offset}
+    {
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(M, N);
+    }
+
+    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t group_offset_;
+    index_t tile_offset_;
+};
+
 /**
  * @brief Simple tile mapping which creates 3D grid of block of threads.
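A toy usage sketch of the two-offset wrapper added above: a stand-in 1-D "underlying map" plus a simplified wrapper that applies (tile_offset - group_offset) before delegating, which is what CalculateBottomIndex does. The CK version wraps a real BlockToCTileMap and ck::make_multi_index; everything here is a reduced, hypothetical illustration.

#include <iostream>
#include <tuple>
#include <utility>

struct ToyUnderlyingMap
{
    int tiles_n; // tiles along N; bottom index = (m_tile, n_tile)
    std::pair<int, int> CalculateBottomIndex(int flat_tile) const
    {
        return {flat_tile / tiles_n, flat_tile % tiles_n};
    }
};

struct ToyOffsettedMap2
{
    ToyUnderlyingMap map_;
    int group_offset_; // first global tile id owned by this GEMM group
    int tile_offset_;  // global tile id the block is currently assigned to

    std::pair<int, int> CalculateBottomIndex(int idx_top) const
    {
        // Same shift as the diff: local tile = top index + tile_offset - group_offset.
        return map_.CalculateBottomIndex(idx_top + tile_offset_ - group_offset_);
    }
    void UpdateTileOffset(int offset) { tile_offset_ = offset; }
};

int main()
{
    ToyOffsettedMap2 b2c{ToyUnderlyingMap{4}, /*group_offset=*/10, /*tile_offset=*/13};
    auto [m, n] = b2c.CalculateBottomIndex(0); // block 0 works on local tile 3 -> (0, 3)
    std::cout << m << ", " << n << '\n';
    b2c.UpdateTileOffset(17);                     // the tile loop moved 4 tiles further
    std::tie(m, n) = b2c.CalculateBottomIndex(0); // now local tile 7 -> (1, 3)
    std::cout << m << ", " << n << '\n';
}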