gaoqiong / composable_kernel_ROCM

Commit 563c1e23, authored Nov 29, 2024 by Astha Rai
Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc
Parents: fb1bc1bf, b14b1973

Changes: 343 files in the commit; 20 changed files shown below, with 2571 additions and 245 deletions.
Changed files (20 shown):

  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp               +1014  -0
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp                           +309  -73
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp         +6    -4
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp  +58   -35
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp         +21   -3
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                       +19   -2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                              +55   -17
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                       +36   -4
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp                                +733  -85
  include/ck/utility/amd_wmma.hpp                                                                               +6    -5
  include/ck/utility/loop_scheduler.hpp                                                                         +0    -1
  include/ck_tile/core.hpp                                                                                      +2    -0
  include/ck_tile/core/arch/amd_buffer_addressing.hpp                                                           +103  -0
  include/ck_tile/core/arch/arch.hpp                                                                            +18   -0
  include/ck_tile/core/arch/utility.hpp                                                                         +24   -0
  include/ck_tile/core/config.hpp                                                                               +5    -0
  include/ck_tile/core/numeric/bfloat16.hpp                                                                     +36   -0
  include/ck_tile/core/tensor/buffer_view.hpp                                                                   +80   -6
  include/ck_tile/core/tensor/load_tile.hpp                                                                     +45   -9
  include/ck_tile/core/tensor/shuffle_tile.hpp                                                                  +1    -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {

// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
// pipelines into the same kernel function. Blockers:
// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not reuse the same LDS buffer if __shared__ is declared inside the block GEMM pipeline.
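// [Illustrative sketch, not part of the original header.] A minimal example of the
// double-LDS ("ping-pong") idea referred to above: two distinct __shared__ arrays guarantee
// that the chunk being written for tile t is never the chunk being read for tile t-1, so no
// ordering is required between the two accesses. The kernels below rely on the same property,
// with the block GEMM pipeline choosing between one and two LDS chunks. All names here are
// hypothetical and the sketch assumes blockDim.x == 256; the last staged tile is left
// unconsumed to keep the example short.
__global__ void ping_pong_sketch(const float* __restrict__ p_in, float* __restrict__ p_out, int num_tiles)
{
    __shared__ float lds_0[256];
    __shared__ float lds_1[256];

    float acc = 0.f;

    for(int t = 0; t < num_tiles; ++t)
    {
        float* write_buf = (t % 2 == 0) ? lds_0 : lds_1; // tile t is staged here
        float* read_buf  = (t % 2 == 0) ? lds_1 : lds_0; // tile t-1 is consumed from here

        write_buf[threadIdx.x] = p_in[t * blockDim.x + threadIdx.x];
        __syncthreads();

        if(t > 0)
        {
            acc += read_buf[threadIdx.x]; // read the previously staged tile
        }
        __syncthreads(); // keep the read ahead of the next overwrite of this chunk
    }

    p_out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}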
template <typename GridwiseGemm,
          typename BatchedGemmArg,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t g_idx = blockIdx.z % karg.Batch;

    const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
    const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
    const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
    const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);

    // populate pointers and descriptors for the Ds
    static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
        // D pointer
        karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
    });

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + a_batch_offset,
        karg.p_b_grid + b_batch_offset,
        karg.p_ds_grid,
        karg.p_c_grid + c_batch_offset,
        p_shared,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.c_element_op);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
          typename BatchedGemmArg,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write can
    // operate on different LDS chunks at the same time without an ordering dependency.
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t g_idx = blockIdx.z % karg.Batch;

    const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
    const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
    const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
    const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);

    // populate pointers and descriptors for the Ds
    static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
        // D pointer
        karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
    });

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + a_batch_offset,
        karg.p_b_grid + b_batch_offset,
        karg.p_ds_grid,
        karg.p_c_grid + c_batch_offset,
        p_shared_0,
        p_shared_1,
        karg,
        karg.a_element_op,
        karg.b_element_op,
        karg.c_element_op);
#else
    ignore = karg;
#endif // end of if (defined(__gfx9__))
}
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA                       = ADataType,
          typename ComputeTypeB                       = BDataType,
          typename LDSTypeA                           = ComputeTypeA,
          typename LDSTypeB                           = ComputeTypeB>
struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
    : public DeviceBatchedGemmV2MultiD<ALayout,
                                       BLayout,
                                       DsLayout,
                                       CLayout,
                                       ADataType,
                                       BDataType,
                                       DsDataType,
                                       CDataType,
                                       AElementwiseOperation,
                                       BElementwiseOperation,
                                       CElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
        ALayout, BLayout, DsLayout, CLayout,
        ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType,
        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
        GemmSpec,
        BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        LDSTypeA,
        LDSTypeB>;
    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       std::array<ck::index_t, NumDTensor> BatchStrideDs,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideDs_(BatchStrideDs),
              BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideA_) * g_idx;
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideB_) * g_idx;
        }

        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
        {
            std::array<long_index_t, NumDTensor> ds_offset_;

            static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
                ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
            });

            return ds_offset_;
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideC_) * g_idx;
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        const std::array<ck::index_t, NumDTensor> BatchStrideDs_;
        index_t BatchStrideC_;
    };
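    // [Illustrative note, not part of the original header.] Worked example of the strided-batch
    // offsets above, with hypothetical sizes: for a packed A tensor of shape [Batch = 4, M = 128,
    // K = 64], BatchStrideA = M * K = 8192, so the workgroup handling batch index
    // g_idx = blockIdx.z % Batch = 3 reads A at p_a_grid + 3 * 8192. The cast to long_index_t
    // before the multiply keeps large batch offsets from overflowing the narrower index_t.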
    struct Argument : public GridwiseGemm::Argument
    {
        index_t Batch;
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;

        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 std::array<const void*, NumDTensor> p_ds_grid_,
                 CDataType* p_e_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 std::array<index_t, NumDTensor> StrideDs_,
                 index_t StrideE_,
                 index_t BatchStrideA_,
                 index_t BatchStrideB_,
                 const std::array<ck::index_t, NumDTensor>& BatchStrideDs_,
                 index_t BatchStrideE_,
                 index_t Batch_,
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
                 CElementwiseOperation c_element_op_)
            : GridwiseGemm::Argument{p_a_grid_,
                                     p_b_grid_,
                                     p_ds_grid_,
                                     p_e_grid_,
                                     M_,
                                     N_,
                                     K_,
                                     StrideA_,
                                     StrideB_,
                                     StrideDs_,
                                     StrideE_,
                                     1,
                                     a_element_op_,
                                     b_element_op_,
                                     c_element_op_},
              Batch{Batch_},
              compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
        {
        }
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
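            // [Illustrative note, not part of the original header.] K_split rounds the per-split
            // K extent up to a whole number of KPerBlock tiles. For example, with hypothetical
            // values K = 1000, KBatch = 1 and KPerBlock = 64: k_grain = 64,
            // (1000 + 64 - 1) / 64 = 16, so K_split = 16 * 64 = 1024, and the main-K-block-loop
            // predicate is evaluated against that rounded extent.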
            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    std::array<std::size_t, NumDTensor> DsSize;

                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer =
                        a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) * arg.Batch;
                    auto size_b_buffer =
                        b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) * arg.Batch;

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });

                    ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config, run_flush_cache, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                            Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                            Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy>;
                        Run(kernel);
                    }
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Two>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Three>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Four>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Five>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Six>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Seven>;
                                Run(kernel);
                            }
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Two>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Three>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Four>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Five>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Six>;
                                Run(kernel);
                            }
                        }
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
                            {
                                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                    Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Seven>;
                                Run(kernel);
                            }
                        }
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
                else
                {
                    if(arg.KBatch > 1)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                    else
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                                Argument, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                            Argument, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
                            Argument, false, InMemoryDataOperationEnum::Set, minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t Batch,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
                             index_t BatchStrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{static_cast<const ADataType*>(p_a),
                        static_cast<const BDataType*>(p_b),
                        p_ds,
                        static_cast<CDataType*>(p_e),
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideE,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideDs,
                        BatchStrideE,
                        Batch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      const std::array<const void*, NumDTensor>& p_ds,
                                                      void* p_e,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t Batch,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      const std::array<ck::index_t, NumDTensor>& StrideDs,
                                                      index_t StrideE,
                                                      index_t BatchStrideA,
                                                      index_t BatchStrideB,
                                                      const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
                                                      index_t BatchStrideE,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<CDataType*>(p_e),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideE,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideDs,
                                          BatchStrideE,
                                          Batch,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceBatchedGemmXdlUniversal"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: " << BlockSize << ", "
            << "BlkTile: " << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: " << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: " << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: " << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: " << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: " << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
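As a usage illustration for the device operation defined in this new header, a minimal host-side sketch might look as follows. This is not taken from the commit: DeviceOp stands for some fully specialized DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 instance with no D tensors, the raw device pointers are assumed to be allocated and filled elsewhere (e.g. via hipMalloc/hipMemcpy), and packed row-major A, B and E tensors are assumed for the stride arithmetic.

// assumed alias, defined elsewhere:
// using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

float run_batched_gemm(const void* p_a, const void* p_b, void* p_e, int M, int N, int K, int batch)
{
    auto device_op = DeviceOp{};

    // no D tensors in this sketch, so the Ds pointer/stride arrays are empty
    auto argument = device_op.MakeArgument(p_a, p_b, {}, p_e,
                                           M, N, K, batch,
                                           /*StrideA*/ K, /*StrideB*/ N, /*StrideDs*/ {}, /*StrideE*/ N,
                                           /*BatchStrideA*/ M * K, /*BatchStrideB*/ K * N,
                                           /*BatchStrideDs*/ {}, /*BatchStrideE*/ M * N,
                                           PassThrough{}, PassThrough{}, PassThrough{});

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("this problem is not supported by the chosen instance");
    }

    auto invoker = device_op.MakeInvoker();
    return invoker.Run(argument, StreamConfig{nullptr, /*time_kernel*/ true});
}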
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp (mode changed 100644 → 100755)

@@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(stream_config.log_level_ > 0)
             {
                 arg.Print();
...
@@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
             index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
             const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
-            hipGetErrorString(hipMemsetAsync(arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
+            if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Atomic)
+            {
+                hip_check_error(hipMemsetAsync(arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
+            }
             const auto Run = [&](const auto& kernel) {
                 dim3 grid_dim;
                 if(arg.Grid_size < 0)
                 {
                     int occupancy, num_cu;
-                    hipError_t rtn;
-                    rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0);
-                    hip_check_error(rtn);
+                    hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
                     hipDeviceProp_t dev_prop;
                     hipDevice_t dev;
-                    rtn = hipGetDevice(&dev);
-                    hip_check_error(rtn);
-                    rtn = hipGetDeviceProperties(&dev_prop, dev);
-                    hip_check_error(rtn);
-                    num_cu = dev_prop.multiProcessorCount;
+                    hip_check_error(hipGetDevice(&dev));
+                    hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+                    num_cu = dev_prop.multiProcessorCount;
                     arg.Grid_size = num_cu * occupancy;
                     grid_dim = arg.Grid_size;
                 }
...
@@ -196,8 +198,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                 else
                 {
-                    ave_time = launch_and_time_kernel(stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
+                    if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Atomic)
+                    {
+                        ave_time = launch_and_time_kernel(stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
+                    }
+                    else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+                    {
+                        char* workspace_semaphore = reinterpret_cast<char*>(arg.p_workspace_) +
+                                                    arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(GemmAccDataType));
+                        auto preprocess = [&]() {
+                            hipMemsetAsync(workspace_semaphore,
+                                           0,
+                                           // sizeof(uint32_t),
+                                           arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
+                                           stream_config.stream_id_);
+                        };
+                        ave_time = launch_and_time_kernel_with_preprocess(stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
+                    }
                 }
             };
...
@@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                    BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                 {
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy>;
-                        Run(kernel);
-                    }
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                    Run(kernel);
                 }
                 // Tail number could be One to Seven
                 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
...
@@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                 {
-                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-                    {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
-                        Run(kernel);
-                    }
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
                 }
                 else
                 {
-                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-                    {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
-                        Run(kernel);
-                    }
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
                 }
             }
...
@@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                 {
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, false, InMemoryDataOperationEnum::Set, minimum_occupancy>;
-                        Run(kernel);
-                    }
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, false, InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                    Run(kernel);
                 }
             }
...
@@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
         }
     };
+    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
+    {
+        const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
+        if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
+        }
+        else
+        {
+            return 0;
+        }
+    }
+    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace, const StreamConfig& = StreamConfig{}) const override
+    {
+        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
+        pArg_->p_workspace_ = p_workspace;
+    }
     static constexpr bool IsValidCompilationParameter()
     {
         // TODO: properly implement this check
...
@@ -464,8 +503,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
                              CElementwiseOperation)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
+        // HS
+        constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+        index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
+        int occupancy, num_cu;
+        const auto calculate_grid_size = [&](const auto& kernel) {
+            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+            hipDeviceProp_t dev_prop;
+            hipDevice_t dev;
+            hip_check_error(hipGetDevice(&dev));
+            hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+            num_cu    = dev_prop.multiProcessorCount;
+            Grid_size = num_cu * occupancy;
+        };
+        if(has_main_k_block_loop)
+        {
+            // Tail number always full
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                calculate_grid_size(kernel);
+            }
+            // Tail number could be One to Seven
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::One>;
+                    calculate_grid_size(kernel);
+                }
+                else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Full>;
+                    calculate_grid_size(kernel);
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Two>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Three>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Four>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Five>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Six>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                    {
+                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Seven>;
+                        calculate_grid_size(kernel);
+                    }
+                }
+            }
+            // Tail number could be Odd or Even
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                    calculate_grid_size(kernel);
+                }
+                else
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                    calculate_grid_size(kernel);
+                }
+            }
+            else
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                    calculate_grid_size(kernel);
+                }
+                else
+                {
+                    const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                    calculate_grid_size(kernel);
+                }
+            }
+        }
+        else
+        {
+            // Tail number always 1
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+            {
+                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, false, InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                calculate_grid_size(kernel);
+            }
+        }
+        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
     }
     static auto MakeInvoker() { return Invoker{}; }
...
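The MakeArgument hunk above sizes the stream-k grid from the kernel's achievable occupancy and the number of compute units. A standalone sketch of that pattern, with hypothetical names and error checking elided for brevity, might look like this:

#include <hip/hip_runtime.h>

// Pick a grid size that fills the GPU with exactly one wave of resident workgroups,
// as the Grid_size = num_cu * occupancy computation above does.
template <typename Kernel>
int full_occupancy_grid_size(Kernel kernel, int block_size)
{
    int occupancy = 0;
    // blocks per compute unit this kernel can sustain at the given block size (0 bytes dynamic LDS)
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0);

    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    hipGetDevice(&dev);
    hipGetDeviceProperties(&dev_prop, dev);

    return dev_prop.multiProcessorCount * occupancy;
}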
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp

...
@@ -381,10 +381,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
         {
             tildes = {i_ztilde, i_ytilde, i_xtilde};
         }
-        else
-        {
-            throw std::runtime_error("wrong! only implemented for 2D and 3D now");
-        }
         const auto a_grid_desc_ak0_m_ak1 = transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
...
@@ -750,6 +746,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
             }
         }
+        // check number of dimensions, only implemented for 2D and 3D now
+        if(NDimSpatial != 2 && NDimSpatial != 3)
+        {
+            return false;
+        }
         return true;
     }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp

...
@@ -18,7 +18,6 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
@@ -78,17 +77,17 @@ template <typename ALayout,
           // TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
           enable_if_t<AK1 == BK1, bool> = false>
 struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
-    : public DeviceGroupedGemmMultipleDSplitK<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>
+    : public DeviceGroupedGemmSplitK<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>
 {
     using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage;
...
@@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
         index_t skipped_group_count_;
         index_t grid_size_;
         // Pointer to device memory with GEMM kernel arguments.
-        const void* p_dev_gemm_args_;
+        void* p_dev_gemm_kargs_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
...
@@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
     /// @return The average kernel execution time (if time measurement is enabled.)
     ///
     float Run(const Argument& arg,
-              const void* dev_gemm_args,
+              void* dev_gemm_args,
              void* dev_gemm_workspace,
              const StreamConfig& stream_config = StreamConfig{})
     {
...
@@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
     ///
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-        if(arg.p_dev_gemm_args_ == nullptr)
+        if(arg.p_dev_gemm_kargs_ == nullptr)
         {
             std::ostringstream err;
             err << "The gemm arguments device buffer is not allocated!"
...
@@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             throw std::runtime_error(err.str());
         }
-        return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
+        return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config);
     }
     float Run(const BaseArgument* p_arg,
...
@@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
     template <bool HasMainKBlockLoop>
     float DispatchKernel(const Argument& arg,
-                         const void* dev_gemm_args,
+                         void* dev_gemm_kargs,
                          void* dev_gemm_workspace,
                          const StreamConfig& stream_config) const
     {
...
@@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
         return LaunchKernel(gemm_kernel,
                             elementwise_kernel,
                             arg,
-                            dev_gemm_args,
+                            dev_gemm_kargs,
                             dev_gemm_workspace,
                             stream_config);
     }
...
@@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
     float LaunchKernel(const KernelFunction& gemm_kernel,
                        const KernelFunction2& elementwise_kernel,
                        const Argument& arg,
-                       const void* dev_gemm_args,
+                       void* dev_gemm_kargs,
                        [[maybe_unused]] void* dev_gemm_workspace,
                        const StreamConfig& stream_config) const
     {
         float time{0.f};
+        hip_check_error(hipMemcpyWithStream(dev_gemm_kargs,
+                                            arg.gemm_kernel_args_.data(),
+                                            arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
+                                            hipMemcpyHostToDevice,
+                                            stream_config.stream_id_));
         auto preprocess = [&]() {
             hip_check_error(hipMemsetAsync(dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
...
@@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
                 dim3(arg.grid_size_),
                 dim3(BlockSize),
                 0,
-                cast_pointer_to_constant_address_space(dev_gemm_args),
+                cast_pointer_to_constant_address_space(dev_gemm_kargs),
                 arg.gemm_kernel_args_.size(),
                 arg.a_element_op_,
                 arg.b_element_op_,
...
@@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
         return str.str();
     }
-    void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
+    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
     {
-        arg.p_dev_gemm_args_ = p_dev_kernel_args;
-        hip_check_error(hipMemcpy(p_dev_kernel_args,
-                                  arg.gemm_kernel_args_.data(),
-                                  GetDeviceKernelArgSize(&arg),
-                                  hipMemcpyHostToDevice));
+        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
     }
-    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
+    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
     {
-        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
+        auto arg = dynamic_cast<const Argument*>(p_arg);
+        if(arg)
+        {
+            return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
     }
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
...
@@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
                                  "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
     }
-    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
-    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
+    [[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch)
     {
-        return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
+        arg.UpdateKBatch(kbatch);
     }
-    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
+    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
     {
-        return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
+        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
+        if(p_arg_)
+        {
+            p_arg_->UpdateKBatch(kbatch);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
     }
 };
...
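Several of the overrides in this commit replace an unchecked `*dynamic_cast<Argument*>(p_arg)` with a null-checked cast that throws a descriptive exception. A minimal sketch of the difference, with hypothetical types, is shown below; the exact wording of the error message in the real code is quoted in the hunks above.

#include <stdexcept>

struct BaseArgument { virtual ~BaseArgument() = default; };
struct Argument : BaseArgument { int kbatch = 1; void UpdateKBatch(int k) { kbatch = k; } };

// Unchecked: if p_arg actually points to a different BaseArgument subclass, dynamic_cast
// yields nullptr and the dereference is undefined behaviour.
void set_kbatch_unchecked(BaseArgument* p_arg, int kbatch)
{
    dynamic_cast<Argument*>(p_arg)->UpdateKBatch(kbatch);
}

// Checked, as in the updated overrides: fail loudly with a descriptive exception instead.
void set_kbatch_checked(BaseArgument* p_arg, int kbatch)
{
    auto arg_ptr = dynamic_cast<Argument*>(p_arg);
    if(arg_ptr)
    {
        arg_ptr->UpdateKBatch(kbatch);
    }
    else
        throw std::runtime_error("argument pointer is not an Argument of this device operation");
}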
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp

...
@@ -20,7 +20,6 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 namespace ck {
...
@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
                                                  ComputeTypeA,
                                                  ComputeTypeB>;
-    using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
+    using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
     using Block2ETileMap  = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
     using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
...
@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         return str.str();
     }
+    void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args, const void* p_host_kernel_args) const
+    {
+        arg.p_dev_gemm_args_ = p_dev_kernel_args;
+        hip_check_error(hipMemcpy(p_dev_kernel_args,
+                                  p_host_kernel_args,
+                                  GetDeviceKernelArgSize(&arg),
+                                  hipMemcpyHostToDevice));
+    }
+    virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
+                                     void* p_dev_kernel_args,
+                                     const void* p_host_kernel_args) const override
+    {
+        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
+    }
     void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
     {
         arg.p_dev_gemm_args_ = p_dev_kernel_args;
     }
-    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
+    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
     {
         return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
     }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

-#pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
...
@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
-        return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
+        auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
+        if(p_arg_)
+        {
+            return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
    }
+    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
+    {
+        return GetWorkSpaceSize(p_arg);
+    }
+    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
+    {
+        return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
+    }
 };
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

...
@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     using Block2ETileMap            = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
     using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
+    // TODO: replace with GroupedGemmKernelArgument
     struct GemmBiasTransKernelArg
     {
         // pointers
...
@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         return str.str();
     }
-    static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
-    {
-        arg.grouped_gemm_kernel_args_dev = kernel_args;
-    }
     // polymorphic
-    void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
+    void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
     {
-        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
+        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
     }
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
-        auto arg = *dynamic_cast<const Argument*>(p_arg);
-        return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
+        auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
     }
     size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
     {
-        auto arg = *dynamic_cast<const Argument*>(p_arg);
-        return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
+        auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
    }
     void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace, const StreamConfig& stream_config = StreamConfig{}) const override
     {
-        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
-        p_arg_->p_workspace_ = p_workspace;
+        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            arg_ptr->p_workspace_ = p_workspace;
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
-        hip_check_error(hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
+        hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_);
     }
     static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
...
@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     // polymorphic
     void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
     {
-        return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
+        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            arg_ptr->UpdateKBatch(k_batch);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
    }
+    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
+    {
+        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
+        if(arg_ptr)
+        {
+            arg_ptr->UpdateKBatch(kbatch);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
+    }
 };
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp

...
@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             return false;
         }
+        if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
+        {
+            return false;
+        }
         bool supported = true;
         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
             const auto& a = arg.gemm_kernel_args_[i].karg_;
             bool group_arg_valid = GridwiseGemm::CheckValidity(a);
             if(not group_arg_valid)
             {
...
@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
-        return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
+        auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
+        if(p_arg_)
+        {
+            return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
    }
+    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
+    {
+        return GetWorkSpaceSize(p_arg);
+    }
+    // TODO: deperecation notice.
     static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
     // polymorphic
     void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
     {
-        return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
+        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
+        if(p_arg_)
+        {
+            p_arg_->UpdateKBatch(kbatch);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
    }
+    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
+    {
+        return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
+    }
 };
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp 100644 → 100755
View file @ 563c1e23
...
@@ -14,6 +14,8 @@
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/utility/workgroup_barrier.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck {
...
@@ -38,7 +40,7 @@ __global__ void
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_);
 #else
    ignore = karg;
 #endif // end of if (defined(__gfx9__))
...
@@ -62,7 +64,13 @@ __global__ void
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
+        karg.p_a_grid,
+        karg.p_b_grid,
+        karg.p_c_grid,
+        p_shared_0,
+        p_shared_1,
+        karg,
+        karg.p_workspace_);
 #else
    ignore = karg;
 #endif // end of if (defined(__gfx9__))
...
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
-             p_c_grid{p_c_grid_}
+             p_c_grid{p_c_grid_},
+             block_2_ctile_map_streamk(M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
        {
        }
...
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
+       BlockToCTileMap_GemmStreamK_v2<MPerBlock,
+                                      NPerBlock,
+                                      KPerBlock,
+                                      StreamKReductionStrategy::Atomic,
+                                      8,
+                                      4>
+           block_2_ctile_map_streamk;
    };

    struct SplitKBatchOffset
...
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

+    __host__ __device__ static constexpr auto GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
+    {
+        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
+                       Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
+                       Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
+                       Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
+    }
+
    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
...
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

+    __host__ __device__ static constexpr auto GetClusterLengthReduction()
+    {
+        // TODO: assume C is row major
+        // TODO: we always first loop over N, then M
+        constexpr auto NPerBlockPow2      = math::next_power_of_two<NPerBlock>();
+        constexpr auto NPerBlockReduction = NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
+        constexpr auto MPerBlockReduction = (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
+        return Sequence<MPerBlockReduction, NPerBlockReduction>{};
+    }
+
+    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
+    {
+        const auto c_partial_acc_block_m_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
+                                                    make_tuple(NPerBlock, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
+                                                    make_tuple(I1, MPerBlock));
+            }
+        }();
+        return c_partial_acc_block_m_n;
+    }
+
    using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
                                                                  NPerBlock,
                                                                  KPerBlock,
...
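GetClusterLengthReduction() splits the thread block so that consecutive threads first cover N in vector-sized chunks and the remaining threads cover M. A small host-side sketch of the same arithmetic with illustrative (not authoritative) tile sizes:

    #include <cstdio>

    // Illustrative compile-time parameters only; the real ones are kernel template arguments.
    constexpr int BlockSize    = 256;
    constexpr int NPerBlock    = 128;
    constexpr int ScalarPerVec = 8; // stands in for CShuffleBlockTransferScalarPerVector_NPerBlock

    constexpr int next_power_of_two(int v) { int p = 1; while(p < v) { p *= 2; } return p; }

    int main()
    {
        constexpr int n_pow2      = next_power_of_two(NPerBlock);               // 128
        constexpr int n_reduction = n_pow2 / ScalarPerVec;                      // 16 lanes along N
        constexpr int m_reduction = (BlockSize + n_reduction - 1) / n_reduction; // 16 rows along M
        std::printf("reduction cluster lengths = <%d, %d>\n", m_reduction, n_reduction);
    }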
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
-                              Problem& problem)
+                              Problem& problem,
+                              void* p_workspace)
    {
        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

        Block2CTileMap_streamk block_2_ctile_map_streamk(
            problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size, problem.Streamk_sel);

        uint32_t iter_start, iter_end;
-       bool is_sk_block, is_dp_block;
+       bool is_sk_block, is_dp_block, is_reduction_block;
        index_t num_k_block_main_loop;

+       const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+           problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+       const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+           MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n, problem.MBlock, problem.NBlock);
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+       uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
+           reinterpret_cast<char*>(p_workspace) +
+           block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
+
        for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims();
            block_idx += gridDim.x)
...
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
            num_k_block_main_loop = iter_end - iter_start;

+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               is_reduction_block = static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.reduction_start_block_idx;
+               if(is_reduction_block)
+               {
+                   // descriptors
+                   constexpr auto cluster_length_reduce = GetClusterLengthReduction();
+                   constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
+                   const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
+                   const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
+                   const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
+
+                   constexpr auto MReduceIters = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+                   constexpr auto NReduceIters = math::integer_divide_ceil(
+                       Number<NPerBlock>{},
+                       cluster_length_reduce.At(I1) * Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
+
+                   constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
+                       make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
+                   constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
+                       make_tuple(I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
+
+                   constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
+
+                   constexpr auto partial_acc_load_step_n = make_multi_index(
+                       0, cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
+                       0, -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_load_step_m = make_multi_index(cluster_length_reduce.At(I0), 0);
+
+                   constexpr auto partial_acc_store_step_n = make_multi_index(
+                       0, 0, 0, cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
+                       0, 0, 0, -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_store_step_m = make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
+
+                   StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, CShuffleBlockTransferScalarPerVector_NPerBlock, true> parcial_acc_buf;
+                   StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, CShuffleBlockTransferScalarPerVector_NPerBlock, true> acc_buf;
+
+                   // start to compute
+                   auto reduction_idx = block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
+                   auto spatial_idx   = block_2_ctile_map_streamk.tile_to_spatial(reduction_idx, problem.M, problem.N);
+
+                   workgroup_barrier wg_barrier(p_semaphore);
+
+                   uint32_t tile_acc_offset_start = block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
+                   uint32_t tile_acc_offset_end   = block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + 1);
+
+                   __syncthreads();
+
+                   auto acc_load = ThreadwiseTensorSliceTransfer_v2<
+                       AccDataType,                                                 // SrcData,
+                       AccDataType,                                                 // DstData,
+                       decltype(c_partial_acc_block_m_n),                           // SrcDesc,
+                       decltype(acc_thread_buf_load_desc),                          // DstDesc,
+                       Sequence<1, CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
+                       Sequence<0, 1>,                                              // DimAccessOrder,
+                       1,                                                           // SrcVectorDim,
+                       CShuffleBlockTransferScalarPerVector_NPerBlock,              // SrcScalarPerVector,
+                       1,                                                           // SrcScalarStrideInVector,
+                       false                                                        // SrcResetCoordinateAfterRun,
+                       >{c_partial_acc_block_m_n,
+                         make_multi_index(thread_m_cluster_id,
+                                          thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock)};
+
+                   auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
+                       AccDataType,                                                       // SrcData,
+                       CDataType,                                                         // DstData,
+                       decltype(acc_thread_buf_store_desc),                               // SrcDesc,
+                       decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),           // DstDesc,
+                       CElementwiseOperation,                                             // ElementwiseOperation,
+                       Sequence<1, 1, 1, CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
+                       Sequence<0, 1, 2, 3>,                                              // DimAccessOrder,
+                       3,                                                                 // DstVectorDim,
+                       CShuffleBlockTransferScalarPerVector_NPerBlock,                    // DstScalarPerVector,
+                       InMemoryDataOperationEnum::Set,                                    // InMemoryDataOperationEnum DstInMemOp,
+                       1,                                                                 // DstScalarStrideInVector,
+                       false                                                              // DstResetCoordinateAfterRun,
+                       >{c_grid_desc_mblock_mperblock_nblock_nperblock,
+                         make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
+                                          thread_m_cluster_id,
+                                          __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
+                                          thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock),
+                         CElementwiseOperation{}};
+
+                   wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
+
+                   if(threadIdx.x == 0)
+                   {
+                       p_semaphore[reduction_idx] = 0;
+                   }
+
+                   using Accumulation = ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
+
+                   for(int i_m = 0; i_m < MReduceIters; i_m++)
+                   {
+                       static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
+                           acc_buf.Clear();
+                           for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
+                           {
+                               auto c_partial_acc_buf =
+                                   make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
+                                       reinterpret_cast<AccDataType*>(p_workspace) + i * c_partial_acc_block_m_n.GetElementSpaceSize(),
+                                       c_partial_acc_block_m_n.GetElementSpaceSize());
+
+                               acc_load.Run(c_partial_acc_block_m_n, c_partial_acc_buf, acc_thread_buf_load_desc,
+                                            make_tuple(I0, I0), parcial_acc_buf);
+
+                               static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}([&](auto i_vec) {
+                                   constexpr auto offset = acc_thread_buf_load_desc.CalculateOffset(make_tuple(0, i_vec));
+                                   Accumulation::Calculate(acc_buf(Number<offset>{}), parcial_acc_buf[Number<offset>{}]);
+                               });
+                           }
+
+                           if(thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock < NPerBlock)
+                           {
+                               acc_store.Run(acc_thread_buf_store_desc, make_tuple(I0, I0, I0, I0), acc_buf,
+                                             c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf);
+                           }
+
+                           if constexpr(NReduceIters != 1)
+                           {
+                               if constexpr(i_n_reduce != (NReduceIters - 1))
+                               {
+                                   acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_n);
+                                   acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_n);
+                               }
+                               else
+                               {
+                                   acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_n_reverse);
+                                   acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_n_reverse);
+                               }
+                           }
+                       });
+                       {
+                           acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_m);
+                           acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_m);
+                       }
+                   }
+                   continue;
+               }
+           }
+
+           // offset for last acc buffer of this block
+           uint32_t block_acc_offset =
+               (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock * NPerBlock;
+
            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
...
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    iter_end - 1, tile_idx, iter_offset);
                iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);

-               const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-                   problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
-               const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
-                   problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
-               const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
-                   problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
-               const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-                   MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n, problem.MBlock, problem.NBlock);
-
-               auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-               const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-               const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-
                auto block_work_idx = block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
...
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+               constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
+                   GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();

                auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<CShuffleDataType*>(p_shared),
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

+               auto c_partial_acc_buf =
+                   make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
+                       reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
+                       c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());
+
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
...
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    make_multi_index(block_m_id, 0, block_n_id, 0),
                    c_element_op};

+               // LDS to global partial acc
+               auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
+                   ThisThreadBlock,       // index_t BlockSize,
+                   CElementwiseOperation, // ElementwiseOperation,
+                                          // InMemoryDataOperationEnum::Set, // DstInMemOp,
+                   Sequence<1,
+                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                            1,
+                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                   CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                   Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                   CShuffleDataType,     // typename SrcData,
+                   CShuffleDataType,     // typename DstData,
+                   decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+                   decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
+                   Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                   3,                    // index_t VectorDim,
+                   CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
+                   false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
+                          // false, othre wise has scratch
+                   false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
+                          // false, othre wise has scratch
+                   {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                    make_multi_index(0, 0, 0, 0),
+                    c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                    make_multi_index(0, 0, 0, 0),
+                    c_element_op};
+
                // space filling curve for threadwise C in VGPR
                constexpr auto sfc_c_vgpr =
                    SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    }
                    else if(is_sk_block)
                    {
-                       // each block copy its data from LDS to global
-                       c_shuffle_block_copy_lds_to_global.template Run<decltype(c_shuffle_block_buf),
-                                                                       decltype(c_grid_buf),
-                                                                       InMemoryDataOperationEnum::AtomicAdd>(
-                           c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                           c_shuffle_block_buf,
-                           c_grid_desc_mblock_mperblock_nblock_nperblock,
-                           c_grid_buf);
+                       if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Atomic)
+                       {
+                           // each block copy its data from LDS to global
+                           c_shuffle_block_copy_lds_to_global.template Run<decltype(c_shuffle_block_buf),
+                                                                           decltype(c_grid_buf),
+                                                                           InMemoryDataOperationEnum::AtomicAdd>(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               c_shuffle_block_buf,
+                               c_grid_desc_mblock_mperblock_nblock_nperblock,
+                               c_grid_buf);
+                       }
+                       else if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+                       {
+                           // constexpr offset
+                           c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               make_tuple(0, 0, 0, 0));
+                           c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
+                               c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                               make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
+
+                           c_block_copy_lds_to_partial_acc.template Run<decltype(c_shuffle_block_buf),
+                                                                        decltype(c_partial_acc_buf),
+                                                                        InMemoryDataOperationEnum::Set>(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               c_shuffle_block_buf,
+                               c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                               c_partial_acc_buf);
+                       }
                    }

                    if constexpr(access_id < num_access - 1)
...
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                            c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                    }
                });
            }

+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               if(is_sk_block)
+               {
+                   // increase the counter for this tile
+                   workgroup_barrier wg_barrier(p_semaphore);
+                   wg_barrier.inc(tile_idx);
+               }
+           }
            }
            // shuffle c and write-out end

            // exit condition
            iter_end -= current_iter_length;
            if(iter_end <= iter_start)
                break;
+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               block_acc_offset -= MPerBlock * NPerBlock;
+           }
            // make sure next loop LDS is ready for use
            block_sync_lds();
            }
        }
        }   // while loop
        }   // for loop
    }

    template <bool HasMainKBlockLoop,
...
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                                   CDataType* p_c_grid,
                                   void* p_shared_0,
                                   void* p_shared_1,
-                                  Problem& problem)
+                                  Problem& problem,
+                                  void* p_workspace)
    {
        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

-       Block2CTileMap_streamk block_2_ctile_map_streamk(
-           problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
-
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

        uint32_t iter_start, iter_end;
-       bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
+       bool is_sk_block, is_dp_block, is_reduction_block;
        index_t num_k_block_main_loop;

+       const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+           problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+       const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+           MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n, problem.MBlock, problem.NBlock);
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+       Block2CTileMap_streamk block_2_ctile_map_streamk(
+           problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size, problem.Streamk_sel);
+
        for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims();
            block_idx += gridDim.x)
...
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
            block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
            num_k_block_main_loop = iter_end - iter_start;

+           uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
+               reinterpret_cast<char*>(p_workspace) +
+               block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
+
+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               is_reduction_block = static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.reduction_start_block_idx;
+               if(is_reduction_block)
+               {
+                   // descriptors
+                   constexpr auto cluster_length_reduce = GetClusterLengthReduction();
+                   constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
+                   const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(make_multi_index(block_idx));
+                   const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
+                   const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
+
+                   constexpr auto MReduceIters = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+                   constexpr auto NReduceIters = math::integer_divide_ceil(
+                       Number<NPerBlock>{},
+                       cluster_length_reduce.At(I1) * Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
+
+                   constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
+                       make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
+                   constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
+                       make_tuple(I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
+
+                   constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
+
+                   constexpr auto partial_acc_load_step_n = make_multi_index(
+                       0, cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
+                       0, -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_load_step_m = make_multi_index(cluster_length_reduce.At(I0), 0);
+
+                   constexpr auto partial_acc_store_step_n = make_multi_index(
+                       0, 0, 0, cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
+                       0, 0, 0, -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * CShuffleBlockTransferScalarPerVector_NPerBlock);
+                   constexpr auto partial_acc_store_step_m = make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
+
+                   StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, CShuffleBlockTransferScalarPerVector_NPerBlock, true> parcial_acc_buf;
+                   StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, CShuffleBlockTransferScalarPerVector_NPerBlock, true> acc_buf;
+
+                   // start to compute
+                   auto reduction_idx = block_idx - block_2_ctile_map_streamk.reduction_start_block_idx;
+                   auto spatial_idx   = block_2_ctile_map_streamk.tile_to_spatial(reduction_idx, problem.M, problem.N);
+
+                   workgroup_barrier wg_barrier(p_semaphore);
+
+                   uint32_t tile_acc_offset_start = block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
+                   uint32_t tile_acc_offset_end   = block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + 1);
+                   uint32_t expected_count        = tile_acc_offset_end - tile_acc_offset_start;
+
+                   if(threadIdx.x == 0)
+                   {
+                       p_semaphore[reduction_idx] = 0;
+                   }
+
+                   __syncthreads();
+
+                   auto acc_load = ThreadwiseTensorSliceTransfer_v2<
+                       AccDataType,                                                 // SrcData,
+                       AccDataType,                                                 // DstData,
+                       decltype(c_partial_acc_block_m_n),                           // SrcDesc,
+                       decltype(acc_thread_buf_load_desc),                          // DstDesc,
+                       Sequence<1, CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
+                       Sequence<0, 1>,                                              // DimAccessOrder,
+                       1,                                                           // SrcVectorDim,
+                       CShuffleBlockTransferScalarPerVector_NPerBlock,              // SrcScalarPerVector,
+                       1,                                                           // SrcScalarStrideInVector,
+                       false                                                        // SrcResetCoordinateAfterRun,
+                       >{c_partial_acc_block_m_n,
+                         make_multi_index(thread_m_cluster_id,
+                                          thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock)};
+
+                   auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
+                       AccDataType,                                                       // SrcData,
+                       CDataType,                                                         // DstData,
+                       decltype(acc_thread_buf_store_desc),                               // SrcDesc,
+                       decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),           // DstDesc,
+                       CElementwiseOperation,                                             // ElementwiseOperation,
+                       Sequence<1, 1, 1, CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
+                       Sequence<0, 1, 2, 3>,                                              // DimAccessOrder,
+                       3,                                                                 // DstVectorDim,
+                       CShuffleBlockTransferScalarPerVector_NPerBlock,                    // DstScalarPerVector,
+                       InMemoryDataOperationEnum::Set,                                    // InMemoryDataOperationEnum DstInMemOp,
+                       1,                                                                 // DstScalarStrideInVector,
+                       false                                                              // DstResetCoordinateAfterRun,
+                       >{c_grid_desc_mblock_mperblock_nblock_nperblock,
+                         make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
+                                          thread_m_cluster_id,
+                                          __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
+                                          thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock),
+                         CElementwiseOperation{}};
+
+#if 0
+                   if(threadIdx.x == 0) {
+                       printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
+                           reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
+                           __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
+                           __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
+                   }
+#endif
+
+                   if(threadIdx.x == 0)
+                   {
+                       atomicAdd(&p_semaphore[reduction_idx], 1);
+                   }
+
+                   wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count);
+
+                   using Accumulation = ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
+
+                   for(int i_m = 0; i_m < MReduceIters; i_m++)
+                   {
+                       static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
+                           acc_buf.Clear();
+                           for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
+                           {
+                               auto c_partial_acc_buf =
+                                   make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
+                                       reinterpret_cast<AccDataType*>(p_workspace) + i * c_partial_acc_block_m_n.GetElementSpaceSize(),
+                                       c_partial_acc_block_m_n.GetElementSpaceSize());
+
+                               acc_load.Run(c_partial_acc_block_m_n, c_partial_acc_buf, acc_thread_buf_load_desc,
+                                            make_tuple(I0, I0), parcial_acc_buf);
+
+                               static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}([&](auto i_vec) {
+                                   constexpr auto offset = acc_thread_buf_load_desc.CalculateOffset(make_tuple(0, i_vec));
+                                   Accumulation::Calculate(acc_buf(Number<offset>{}), parcial_acc_buf[Number<offset>{}]);
+                               });
+                           }
+
+                           if(thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock < NPerBlock)
+                           {
+                               acc_store.Run(acc_thread_buf_store_desc, make_tuple(I0, I0, I0, I0), acc_buf,
+                                             c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf);
+                           }
+
+                           if constexpr(NReduceIters != 1)
+                           {
+                               if constexpr(i_n_reduce != (NReduceIters - 1))
+                               {
+                                   acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_n);
+                                   acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_n);
+                               }
+                               else
+                               {
+                                   acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_n_reverse);
+                                   acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_n_reverse);
+                               }
+                           }
+                       });
+                       {
+                           acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, partial_acc_load_step_m);
+                           acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, partial_acc_store_step_m);
+                       }
+                   }
+                   continue;
+               }
+           }
+
+           // offset for last acc buffer of this block
+           uint32_t block_acc_offset =
+               (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock * NPerBlock;
+
            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
...
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    iter_end - 1, tile_idx, iter_offset);
                iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);

-               const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-                   problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
-               const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
-                   problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
-               const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
-                   problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
-               const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-                   MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n, problem.MBlock, problem.NBlock);
-
-               auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-               const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-               const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-
                auto block_work_idx = block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
...
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+               constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
+                   GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();

                auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<CShuffleDataType*>(p_shared_0),
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

+               auto c_partial_acc_buf =
+                   make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
+                       reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
+                       c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());
+
                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
...
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    make_multi_index(block_m_id, 0, block_n_id, 0),
                    c_element_op};

+               // LDS to global partial acc
+               auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
+                   ThisThreadBlock,       // index_t BlockSize,
+                   CElementwiseOperation, // ElementwiseOperation,
+                                          // InMemoryDataOperationEnum::Set, // DstInMemOp,
+                   Sequence<1,
+                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                            1,
+                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                   CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                   Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                   CShuffleDataType,     // typename SrcData,
+                   CShuffleDataType,     // typename DstData,
+                   decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+                   decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
+                   Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                   3,                    // index_t VectorDim,
+                   CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
+                   false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
+                          // false, othre wise has scratch
+                   false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
+                          // false, othre wise has scratch
+                   {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                    make_multi_index(0, 0, 0, 0),
+                    c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                    make_multi_index(0, 0, 0, 0),
+                    c_element_op};
+
                // space filling curve for threadwise C in VGPR
                constexpr auto sfc_c_vgpr =
                    SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...
@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    }
                    else if(is_sk_block)
                    {
-                       // each block copy its data from LDS to global
-                       c_shuffle_block_copy_lds_to_global.template Run<decltype(c_shuffle_block_buf),
-                                                                       decltype(c_grid_buf),
-                                                                       InMemoryDataOperationEnum::AtomicAdd>(
-                           c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                           c_shuffle_block_buf,
-                           c_grid_desc_mblock_mperblock_nblock_nperblock,
-                           c_grid_buf);
+                       if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Atomic)
+                       {
+                           // each block copy its data from LDS to global
+                           c_shuffle_block_copy_lds_to_global.template Run<decltype(c_shuffle_block_buf),
+                                                                           decltype(c_grid_buf),
+                                                                           InMemoryDataOperationEnum::AtomicAdd>(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               c_shuffle_block_buf,
+                               c_grid_desc_mblock_mperblock_nblock_nperblock,
+                               c_grid_buf);
+                       }
+                       else if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+                       {
+                           // constexpr offset
+                           c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               make_tuple(0, 0, 0, 0));
+                           c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
+                               c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                               make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
+
+                           c_block_copy_lds_to_partial_acc.template Run<decltype(c_shuffle_block_buf),
+                                                                        decltype(c_partial_acc_buf),
+                                                                        InMemoryDataOperationEnum::Set>(
+                               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                               c_shuffle_block_buf,
+                               c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
+                               c_partial_acc_buf);
+                       }
                    }

                    if constexpr(access_id < num_access - 1)
                    {
...
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
                    }
                });
            }

            // exit condition
            iter_end -= current_iter_length;
            if(iter_end <= iter_start)
                break;
+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               block_acc_offset -= MPerBlock * NPerBlock;
+           }
            // make sure next loop LDS is ready for use
            block_sync_lds();
            }
+           if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction)
+           {
+               if(is_sk_block)
+               {
+                   // increase the counter for this tile
+                   workgroup_barrier wg_barrier(p_semaphore);
+                   wg_barrier.inc(0);
+               }
+           }
        }
    }
...
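Taken together, the stream-K Reduction path above lays the workspace out as per-partial accumulator tiles followed by one uint32_t semaphore slot per output tile: stream-K blocks store their partial MPerBlock x NPerBlock tile and bump the tile's counter, and the reduction block spins on the counter, sums the partials, and writes C. A host-side sketch of that layout under hypothetical sizes (the real offsets come from BlockToCTileMap_GemmStreamK_v2::get_workspace_size_for_acc):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Hypothetical numbers, for illustration only.
        constexpr std::size_t MPerBlock     = 128;
        constexpr std::size_t NPerBlock     = 128;
        constexpr std::size_t num_partials  = 24; // partial accumulator tiles written by stream-K blocks
        constexpr std::size_t num_out_tiles = 16; // one semaphore per reduced C tile

        const std::size_t acc_bytes = num_partials * MPerBlock * NPerBlock * sizeof(float); // AccDataType = float here
        const std::size_t sem_bytes = num_out_tiles * sizeof(unsigned int);                 // p_semaphore lives after the accumulators
        std::printf("workspace: %zu bytes of partial tiles + %zu bytes of semaphores\n", acc_bytes, sem_bytes);
    }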
include/ck/utility/amd_wmma.hpp
View file @ 563c1e23
...
@@ -13,6 +13,11 @@ namespace ck {
    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
+#define __gfx12__
+#endif

 /********************************WAVE32 MODE***********************************************/
 // src: fp16, dst: fp32
...
@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
-#if defined(__gfx11__)
+#if defined(__gfx11__) || defined(__gfx12__)
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
            reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
...
@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
 // gfx12
 /********************************WAVE32 MODE***********************************************/
-#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
-#define __gfx12__
-#endif

 // src: fp16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
...
include/ck/utility/loop_scheduler.hpp
View file @ 563c1e23
...
@@ -8,7 +8,6 @@
 #pragma once

 #include "ck/utility/common_header.hpp"
-#include "ck/tensor_description/tensor_adaptor.hpp"

 namespace ck {
...
include/ck_tile/core.hpp
View file @ 563c1e23
...
@@ -52,6 +52,7 @@
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
 #include "ck_tile/core/tensor/tile_window_linear.hpp"
+#include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
...
@@ -62,6 +63,7 @@
 #include "ck_tile/core/utility/philox_rand.hpp"
 #include "ck_tile/core/utility/random.hpp"
 #include "ck_tile/core/utility/reduce_operator.hpp"
+#include "ck_tile/core/utility/static_counter.hpp"
 #include "ck_tile/core/utility/to_sequence.hpp"
 #include "ck_tile/core/utility/transpose_vectors.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @ 563c1e23
...
@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }

+CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
+{
+    asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
+}
+
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add_if;
+
+template <bool pre_nop>
+struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t flag = 1)
+    {
+        static_assert(sizeof(T) == 4);
+        auto save_exec = __builtin_amdgcn_read_exec();
+        using mbuf_t   = float;
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset), "v"(flag), "s"(save_exec)
+                     : "memory");
+    }
+};
+
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add;
+
+template <bool pre_nop>
+struct buffer_atomic_add<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t /*flag = 1*/)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = float;
+        asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
+                     : "memory");
+    }
+};
+
 namespace impl {

 // below type indicate the data type used for buffer load inline asm
 // clang-format off
...
@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }

+CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
+{
+    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
+}
+
 // buffer load i8
 CK_TILE_DEVICE_EXTERN int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
...
@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
 #endif
 }

+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
+                                              T* p_dst_wave,
+                                              const index_t dst_thread_element_offset,
+                                              const index_t dst_linear_element_offset,
+                                              const bool dst_thread_element_valid,
+                                              const index_t dst_element_space_size,
+                                              bool_constant<pre_nop> = {})
+{
+    const int32x4_t dst_wave_buffer_resource =
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
+
+    if constexpr(oob_conditional_check)
+    {
+        buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              0,
+                                              dst_linear_addr_offset,
+                                              dst_thread_element_valid);
+    }
+    else
+    {
+        buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           0,
+                                           dst_linear_addr_offset,
+                                           1);
+    }
+}
+
 // buffer_atomic_max requires:
 // 1) p_dst_wave must point to global memory
 // 2) p_dst_wave must be a wavewise pointer.
...
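The new buffer_atomic_add specializations wrap the global_atomic_pk_add_bf16 instruction, which atomically adds two packed bf16 lanes held in one 32-bit register; the *_if variant additionally masks the lane off via exec when the per-thread flag is zero. A portable host-side sketch of the same semantics using a compare-and-swap loop (illustration only, not the device code path; the truncating bf16 conversion is a simplification):

    #include <atomic>
    #include <cstdint>
    #include <cstring>

    static std::uint16_t to_bf16(float f)   { std::uint32_t u; std::memcpy(&u, &f, 4); return std::uint16_t(u >> 16); }
    static float from_bf16(std::uint16_t h) { std::uint32_t u = std::uint32_t(h) << 16; float f; std::memcpy(&f, &u, 4); return f; }

    // Emulates a predicated packed-bf16 atomic add on the 32-bit word at 'dst'.
    void pk_add_bf16_emulated(std::atomic<std::uint32_t>& dst, std::uint32_t packed_addend, bool flag)
    {
        if(!flag) return; // corresponds to v_cmpx_le_u32 masking the lane off
        std::uint32_t old_word = dst.load(std::memory_order_relaxed);
        std::uint32_t new_word;
        do
        {
            const float lo = from_bf16(std::uint16_t(old_word))       + from_bf16(std::uint16_t(packed_addend));
            const float hi = from_bf16(std::uint16_t(old_word >> 16)) + from_bf16(std::uint16_t(packed_addend >> 16));
            new_word = (std::uint32_t(to_bf16(hi)) << 16) | to_bf16(lo);
        } while(!dst.compare_exchange_weak(old_word, new_word, std::memory_order_relaxed));
    }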
include/ck_tile/core/arch/arch.hpp
View file @ 563c1e23
...
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
 #endif
 }

+CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
+{
+#ifdef __gfx12__
+    asm volatile("s_wait_loadcnt %0\n"
+                 "s_barrier_signal -1\n"
+                 "s_barrier_wait -1"
+                 :
+                 : "n"(cnt)
+                 : "memory");
+#else
+    asm volatile("s_waitcnt vmcnt(%0)\n"
+                 "s_barrier"
+                 :
+                 : "n"(cnt)
+                 : "memory");
+#endif
+}
+
 CK_TILE_DEVICE void block_sync_lds_direct_load()
 {
    asm volatile("\
...
include/ck_tile/core/arch/utility.hpp
View file @ 563c1e23
...
@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
 #endif
 }

+template <typename T>
+CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
+{
+    static_assert(sizeof(T) == 4);
+    // per-thread v_flag store into 2x sgpr
+    uint32x2_t exec_flag;
+    asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
+                 : [s_exec_flag] "=s"(exec_flag)
+                 : [v_flag] "v"(v_flag));
+    return exec_flag;
+}
+
+template <typename X, typename Y>
+CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
+{
+    static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
+    // per-thread cmp store into 2x sgpr
+    uint32x2_t exec_flag;
+    asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
+                 : [s_exec_flag] "=s"(exec_flag)
+                 : [v_x] "v"(x), [v_y] "v"(y));
+    return exec_flag;
+}
+
 } // namespace ck_tile
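flag_to_exec and cmp_lt_to_exec turn a per-thread VGPR comparison into a wave-wide lane mask held in a pair of SGPRs, which callers can then feed back into exec. The mask value itself is close to what HIP's __ballot produces; a hedged device-side sketch of that equivalence (assuming only the mask value matters, not its register placement):

    #include <hip/hip_runtime.h>

    // Bit i is set when lane i has v_flag >= 1, roughly the mask flag_to_exec() materializes.
    __device__ unsigned long long flag_to_mask(unsigned int v_flag)
    {
        return __ballot(v_flag >= 1u);
    }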
include/ck_tile/core/config.hpp
View file @ 563c1e23
...
@@ -64,6 +64,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
+#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4

 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...
@@ -225,3 +226,7 @@
 #ifndef CK_TILE_WORKAROUND_SWDEV_383542
 #define CK_TILE_WORKAROUND_SWDEV_383542 1
 #endif
+
+#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
+#endif
include/ck_tile/core/numeric/bfloat16.hpp
View file @ 563c1e23
...
@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
    truncate_with_nan,
    truncate,
    standard_asm,
+   rta_asm, // round to nearest away
 };

 template <bf16_rounding_mode rounding =
...
@@ -180,6 +181,39 @@
    return uint16_t(u.int32);
 }

+// TODO: do we need this on host?
+CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
+
+CK_TILE_DEVICE uint16_t float_to_bf16_rta_asm(float f)
+{
+    union
+    {
+        float fp32;
+        struct
+        {
+            uint16_t lo;
+            uint16_t hi;
+        };
+    } u = {f};
+    const uint32_t low_nan = 0x7fff;
+    const uint32_t hi_nan  = 0x7fff0000;
+    using uint32x2_t       = uint32_t __attribute__((ext_vector_type(2)));
+    uint32x2_t check_nan;
+    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x]\n"
+                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1\n"
+                 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
+                 : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
+                 : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
+    // Note: in above code snipet, we use hi 16 bit
+    return u.hi;
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
...
@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
        return float_to_bf16_rtn_asm(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
+   else if constexpr(rounding == bf16_rounding_mode::rta_asm)
+       return float_to_bf16_rta_asm(f);
    else
        return float_to_bf16_truc_raw(f);
 }
...
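The rta_asm path rounds to nearest with ties away from zero: v_add3_u32 adds 0x7fff + 1 (i.e. 0x8000) to the float's bit pattern so the carry propagates into bit 16, and v_cndmask_b32 substitutes a NaN pattern when v_cmp_u_f32 flags the input as unordered. A host-side reading of that sequence in plain integer arithmetic (a sketch, not code from the file):

    #include <cstdint>
    #include <cstring>

    inline std::uint16_t float_to_bf16_rta_host(float f)
    {
        if(f != f)                      // unordered with itself: NaN, mirror the v_cndmask_b32 path
            return std::uint16_t(0x7fff);
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return std::uint16_t((u + 0x8000u) >> 16); // add 0x7fff + 1, keep the high 16 bits
    }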
include/ck_tile/core/tensor/buffer_view.hpp
View file @ 563c1e23
...
@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
+             bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i,
+                              index_t linear_offset,
+                              bool is_valid_element,
+                              const X& x,
+                              bool_constant<oob_conditional_check> = {})
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, linear_offset, is_valid_element, x);
+           this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
-           this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
+           this->template atomic_add<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
-           this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
+           this->template atomic_max<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
-           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
+           auto tmp = this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
+           this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x + tmp);
            // tmp += x;
            // this->template set<X>(i, is_valid_element, tmp);
        }
    }

+   // i is offset of T, not X. i should be aligned to X
+   template <memory_operation_enum Op,
+             typename X,
+             bool oob_conditional_check = true,
+             bool pre_nop = false,
+             typename std::enable_if<
+                 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                              typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                 bool>::type = false>
+   CK_TILE_DEVICE void update_raw(index_t i,
+                                  index_t linear_offset,
+                                  bool is_valid_element,
+                                  const X& x,
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
+   {
+       if constexpr(Op == memory_operation_enum::set)
+       {
+           this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
+       }
+       else if constexpr(Op == memory_operation_enum::atomic_add)
+       {
+           this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(i, linear_offset, is_valid_element, x);
+       }
+       else if constexpr(Op == memory_operation_enum::atomic_max)
+       {
+           // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
+       }
+   }
+
    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
...
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
    }

    template <typename X,
+             bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             bool pre_nop = true,
+             typename std::enable_if<
+                 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                              typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                 bool>::type = false>
+   CK_TILE_DEVICE void atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+   {
+       // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
+
+       // X contains multiple T
+       constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
+       constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
+
+       static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                     "wrong! X should contain multiple T");
+
+       static_assert(get_address_space() == address_space_enum::global, "only support global mem");
+
+       constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+       amd_buffer_atomic_add_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+           x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
+   }
+
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
include/ck_tile/core/tensor/load_tile.hpp
View file @ 563c1e23
...
@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
+         index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                             number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
-   return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+   return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
+         index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
+                             number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
-   return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+   return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename DistributedTensor_,
...
@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
+         index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
+                             number<i_access> = {},
                              bool_constant<oob_conditional_check> = {})
{
-   return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
+   return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

+template <typename DistributedTensor_,
+         typename BottomTensorView_,
+         typename WindowLengths_,
+         typename TileDistribution_,
+         typename LinearBottomDims_,
+         index_t i_access = -1,
+         bool oob_conditional_check = true>
+CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
+                              const tile_window_linear<BottomTensorView_,
+                                                       WindowLengths_,
+                                                       TileDistribution_,
+                                                       LinearBottomDims_>& tile_window,
+                              number<i_access> = {},
+                              bool_constant<oob_conditional_check> = {})
+{
+   return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
+}
+
/**
...
@@ -76,6 +100,7 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
+         index_t i_access = -1,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
+                                 number<i_access> = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop> = {})
{
-   tile_window.load_raw(tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+   tile_window.load_raw(tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename T,
...
@@ -95,6 +121,7 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
+         index_t i_access = -1,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                                 number<i_access> = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop> = {})
{
-   tile_window.load_raw(tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+   tile_window.load_raw(tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
...
@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
+         index_t i_access = -1,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto
...
@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window,
+                   number<i_access> = {},
                    bool_constant<oob_conditional_check> = {},
                    bool_constant<pre_nop> = {})
{
-   return tile_window.async_load_raw(lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+   return tile_window.async_load_raw(lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
...
@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
+         index_t i_access = -1,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                                       number<i_access> = {},
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop> = {})
{
-   return tile_window.async_load_raw(lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+   return tile_window.async_load_raw(lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
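The added i_access parameter lets a caller replay a single access of the window's access pattern; the default number<-1> keeps the previous load-everything behaviour, and the argument is now forwarded instead of being hard-coded to -1. A usage fragment (assumed to sit inside an existing ck_tile kernel; a_window and a_tile are placeholders for a tile window and a matching distributed tensor built elsewhere):

    // using namespace ck_tile;
    auto a_full = load_tile(a_window);                 // default: all accesses, OOB check on
    load_tile_raw(a_tile, a_window, number<0>{});      // issue only access 0 of the pattern
    load_tile_raw(a_tile, a_window, number<1>{},       // access 1, with the OOB check explicit
                  bool_constant<true>{});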
include/ck_tile/core/tensor/shuffle_tile.hpp
View file @ 563c1e23
...
@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
    }
    else
    {
-       // NOT implemented
+       static_assert(false, "The shuffle should always happen!");
    }
 }
...