gaoqiong / composable_kernel · Commit 47a2a8e1 (unverified)
Authored Nov 28, 2023 by arai713; committed by GitHub on Nov 28, 2023.

    Merge branch 'develop' into transpose_5d

Parents: 2bf883d6, ae5e5181
Changes: 65 · Showing 20 changed files with 2806 additions and 185 deletions (+2806, -185)
Changed files shown on this page:

include/ck/ck.hpp  (+3, -0)
include/ck/host_utility/device_prop.hpp  (+7, -0)
include/ck/stream_config.hpp  (+2, -2)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp  (+314, -0)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (+414, -0)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp  (+392, -0)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  (+18, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (+991, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  (+7, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp  (+101, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  (+11, -0)
include/ck/utility/amd_buffer_addressing.hpp  (+37, -0)
include/ck/utility/dynamic_buffer.hpp  (+20, -0)
include/ck/utility/synchronization.hpp  (+9, -0)
include/ck/utility/type_convert.hpp  (+189, -141)
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp  (+34, -0)
library/src/tensor_operation_instance/gpu/CMakeLists.txt  (+147, -42)
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt  (+5, -0)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp  (+54, -0)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp  (+51, -0)
include/ck/ck.hpp

@@ -134,6 +134,9 @@
 // inner product using V_DOT with DPP8 modifiers
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1

 // set stochastic rounding as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 1

 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
include/ck/host_utility/device_prop.hpp

@@ -58,4 +58,11 @@ inline bool is_xdl_supported()
           ck::get_device_name() == "gfx942";
}

inline bool is_lds_direct_load_supported()
{
    // Check if direct loads from global memory to LDS are supported.
    return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}

} // namespace ck
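The helper above is a host-side capability query; a minimal sketch of how a caller might gate kernel selection on it is shown below. The fallback helpers are hypothetical and not part of this commit.

// Illustrative sketch, not part of the commit: dispatch based on the new query.
#include "ck/host_utility/device_prop.hpp"

void run_gemm_with_best_transfer()
{
    if(ck::is_lds_direct_load_supported())
    {
        launch_lds_direct_load_gemm();  // hypothetical helper using the direct-load kernels
    }
    else
    {
        launch_register_staged_gemm();  // hypothetical conventional global->register->LDS path
    }
}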
include/ck/stream_config.hpp

@@ -11,6 +11,6 @@ struct StreamConfig
     hipStream_t stream_id_ = nullptr;
     bool time_kernel_      = false;
     int log_level_         = 0;
-    int cold_niters_       = 50;
-    int nrepeat_           = 200;
+    int cold_niters_       = 1;
+    int nrepeat_           = 10;
 };
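The change lowers the default warm-up and measurement iteration counts; callers that want longer timing runs can still pass an explicit StreamConfig. A minimal sketch, with illustrative values and an invoker/argument pair assumed to come from a CK device operation:

// Illustrative sketch, not part of the commit: timing a kernel with explicit iteration counts.
StreamConfig cfg{nullptr, /*time_kernel=*/true, /*log_level=*/0,
                 /*cold_niters=*/50, /*nrepeat=*/200};
float avg_ms = invoker.Run(argument, cfg); // invoker/argument obtained from a device operation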
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

/**
 * Transfer that uses direct load instructions to copy data from global to LDS memory.
 *
 * Traditional loads first copy data from global to registers, and then from registers to LDS.
 * Direct loads do not need an intermediate step; data is copied directly from global to LDS,
 * without the use of additional registers.
 *
 * However, the instruction has limitations:
 * - each thread must copy exactly a single DWORD - 4 bytes;
 * - threads within a single wavefront must write consecutive DWORDs into LDS
 *   (data in global do not need to be contiguous; each thread might have its own offset).
 *
 * To make sure that all the transfers have finished, the `waitcnt` instruction must be used with
 * `vmcnt` instead of `lgkmcnt`.
 *
 * Limitations of the transfer class:
 * - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
 * - `DstVectorDim` must be the last dimension;
 * - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
 * - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
 *   (for example, if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
 *   `ScalarPerVector` must be 2);
 * - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
 *   the same dimension;
 * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
 *   they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
 *   to guarantee that.
 */
template <typename ThreadGroup,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t ScalarPerVector>
struct ThreadGroupTensorSliceTransfer_DirectLoad
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    static constexpr auto block_slice_lengths    = BlockSliceLengths{};
    static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};

    static constexpr auto thread_single_load_size = generate_sequence(
        detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});

    // After a load, each thread moves by `thread_steps` instead of loading the next elements.
    // It makes the whole wavefront load contiguous memory, which is required for direct loads.
    static constexpr auto thread_steps         = thread_cluster_lengths * thread_single_load_size;
    static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;

    static __device__ constexpr bool AreThreadClusterLengthsValid()
    {
        // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
        // LDS by the threads from a single wavefront.
        // Examples (assuming 64 threads in a wavefront, 128 in a thread block):
        // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
        //    data type = fp32 -> ScalarPerVector = 1
        //    INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
        //             write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
        //             [0, 4, 0].
        //    VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
        //           threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
        // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
        //    data type = fp16 -> ScalarPerVector = 2
        //    NOTE: ThreadClusterLengths must take into account that each thread writes two
        //          elements (single DWORD) along the contiguous dimension.
        //    INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
        //             8 * 2 elements of K1PerBlock and there are only 8;
        //             ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
        //             write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
        //             writes [1, 0, 0] instead of [0, 8, 0].
        //    VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
        //           first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
        //           elements = 64 consecutive DWORDs.
        int num_contiguous_dwords = 1;
        bool is_contiguous        = true;
        static_for<0, nDim, 1>{}([&](auto i) {
            if(is_contiguous)
            {
                num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
            }
            if(thread_slice_lengths[nDim - i - 1] > 1)
            {
                is_contiguous = false;
            }
        });
        constexpr index_t wavefront_size = get_warp_size();
        const bool wave_contiguous       = num_contiguous_dwords % wavefront_size == 0;

        bool thread_slice_lengths_correct = true;
        static_for<0, nDim, 1>{}([&](auto i) {
            if(thread_slice_lengths[i] <= 0)
            {
                thread_slice_lengths_correct = false;
            }
        });

        return wave_contiguous && thread_slice_lengths_correct;
    }

    __device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin)
    {
        static_assert(ck::is_same_v<SrcData, DstData>,
                      "Direct load transfer does not support datatypes conversion. Source and "
                      "destination data types must be the same.");

        static_assert(
            DstVectorDim == nDim - 1,
            "Direct load transfer requires the destination vector dimension to be the last one.");

        static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
                      "When loading more than one element per thread at once, the contiguous "
                      "dimension must be the same between source and destination.");

        constexpr auto dword_bytes           = 4;
        constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
        static_assert(bytes_per_thread_load == dword_bytes,
                      "Direct load transfer requires each thread to load exactly a single "
                      "DWORD of data.");

        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size(),
                      "Inconsistent number of dimensions across lengths and descriptors.");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "The number of threads cannot be less than the number of elements in "
                      "thread cluster lengths.");

        static_assert(
            AreThreadClusterLengthsValid(),
            "Thread cluster lengths are incorrect. They must be set in a way that allows a single "
            "wavefront to write contiguous DWORDs into LDS memory. ");

        const auto thread_cluster_idx =
            thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));

        const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;

        SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
        SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_idx);
        src_slice_origin_ = src_slice_origin_idx;
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_        = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
        dst_slice_origin_ = dst_slice_origin_idx;
    }

    __device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
    }

    template <typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        static_assert(
            ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
            "SrcBuffer and SrcData data types must be consistent.");
        static_assert(
            ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
            "DstBuffer and DstData data types must be consistent.");

        constexpr auto dst_access_lengths = thread_slice_lengths;

        const auto dst_forward_steps  = generate_steps(dst_desc, 1);
        const auto dst_backward_steps = generate_steps(dst_desc, -1);
        const auto src_forward_steps  = generate_steps(src_desc, 1);
        const auto src_backward_steps = generate_steps(src_desc, -1);

        // Loop over the destination block and copy data.
        static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            const auto src_offset = src_coord_.GetOffset();
            const auto dst_offset = dst_coord_.GetOffset();

            // Check if src data is not in the logic padding area.
            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
                dst_buf, src_offset, dst_offset, is_src_valid);

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // Decide whether to move forward or backward.
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
                    }
                    else
                    {
                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
                    }
                }
            });
        });

        // Reset the destination slice since the entire buffer has been already filled.
        ResetDstSliceWindow(dst_desc);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        src_slice_origin_ = src_slice_origin_ + step;
        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
    }

    template <typename DescType>
    __device__ auto generate_steps(const DescType& desc, int sign)
    {
        return generate_tuple(
            [&](auto i) {
                Index step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
                });

                return make_tensor_coordinate_step(desc, step_idx);
            },
            Number<nDim>{});
    }

    private:
    static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});

    SrcCoord src_coord_;
    DstCoord dst_coord_;
    Index src_slice_origin_;
    Index dst_slice_origin_;
};

} // namespace ck
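Because each thread must move exactly one DWORD, ScalarPerVector is fully determined by the element size; the small sketch below spells out that relationship (the helper name is hypothetical, not part of the commit).

// Illustrative sketch, not part of the commit: one DWORD (4 bytes) per thread per load.
// fp32 -> ScalarPerVector = 1, fp16/bf16 -> 2, 8-bit types -> 4.
template <typename DataType>
constexpr ck::index_t scalar_per_vector_for_direct_load()
{
    static_assert(4 % sizeof(DataType) == 0, "element size must divide a DWORD");
    return 4 / sizeof(DataType);
}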
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
          typename ADataType, typename BDataType, typename AccDataType, typename CShuffleDataType,
          typename DsDataType, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1,
          index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim, index_t ABlockTransferScalarPerVector,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim, index_t BBlockTransferScalarPerVector,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v4,
          typename ComputeDataType    = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    : public DeviceGemmMultipleD<ALayout, BLayout, DsLayout, ELayout,
                                 ADataType, BDataType, DsDataType, EDataType,
                                 AElementwiseOperation, BElementwiseOperation,
                                 CDEElementwiseOperation>
{
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t NumDTensor = DsDataType::Size();

    using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
        ALayout, BLayout, DsLayout, ELayout,
        ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        GemmSpec, NumGemmKPrefetchStage, BlockSize,
        MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched, PipelineVer>;

    using Argument = typename GridwiseGemm::Argument;

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    typename GridwiseGemm::AGridDesc_AK0_M_AK1,
                    typename GridwiseGemm::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::Block2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_etile_map_);
            };

            const auto K = arg.a_grid_desc_m_k_.GetLength(I1);

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!ck::is_lds_direct_load_supported())
        {
            return false;
        }

        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // Check vector load of Ds.
            // For now, only the RowMajor layout is supported.
            bool all_valid = true;
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                if constexpr(!is_same_v<DLayout, Row>)
                {
                    all_valid = false;
                }
            });
            if(!all_valid)
            {
                return false;
            }

            // Check vector load of E.
            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a, p_b, p_ds, p_e, MRaw, NRaw, KRaw, StrideA, StrideB, StrideDs, StrideE,
                        a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t StrideA,
                        index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a, p_b, p_ds, p_e, MRaw, NRaw, KRaw, StrideA, StrideB,
                                          StrideDs, StrideE, a_element_op, b_element_op,
                                          cde_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};

        // clang-format off
        str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferScalarPerVector << ", "
            << BBlockTransferScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout, typename BLayout, typename ELayout,
          typename ADataType, typename BDataType, typename EDataType,
          typename AccDataType, typename CShuffleDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1,
          index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim, index_t ABlockTransferScalarPerVector,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim, index_t BBlockTransferScalarPerVector,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v4,
          typename ComputeDataType    = EDataType>
struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad
    : public DeviceGemm<ALayout, BLayout, ELayout,
                        ADataType, BDataType, EDataType,
                        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>
{
    static constexpr auto I1 = Number<1>{};

    using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
        ALayout, BLayout, ck::Tuple<>, ELayout,
        ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, ck::Tuple<>, EDataType,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        GemmSpec, NumGemmKPrefetchStage, BlockSize,
        MPerBlock, NPerBlock, KPerBlock, AK1, BK1,
        MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim, ABlockTransferScalarPerVector, ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched, PipelineVer>;

    using Argument = typename GridwiseGemm::Argument;

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    typename GridwiseGemm::AGridDesc_AK0_M_AK1,
                    typename GridwiseGemm::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::Block2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              // arg.p_ds_grid_,
                                              ck::Tuple<>{},
                                              arg.p_e_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_etile_map_);
            };

            const auto K = arg.a_grid_desc_m_k_.GetLength(I1);

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!ck::is_lds_direct_load_supported())
        {
            return false;
        }

        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // Check vector load of E.
            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             void* p_e,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        using EmptyDsPointers = std::array<const void*, 0>;
        using EmptyDsStrides  = std::array<ck::index_t, 0>;

        return Argument{p_a, p_b, EmptyDsPointers{}, p_e, MRaw, NRaw, KRaw, StrideA, StrideB,
                        EmptyDsStrides{}, StrideE, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_e,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideE,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CDEElementwiseOperation cde_element_op) override
    {
        using EmptyDsPointers = std::array<const void*, 0>;
        using EmptyDsStrides  = std::array<ck::index_t, 0>;

        return std::make_unique<Argument>(p_a, p_b, EmptyDsPointers{}, p_e, MRaw, NRaw, KRaw,
                                          StrideA, StrideB, EmptyDsStrides{}, StrideE,
                                          a_element_op, b_element_op, cde_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};

        // clang-format off
        str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferScalarPerVector << ", "
            << BBlockTransferScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
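For reference, host-side use of a device op like the one above follows the usual CK MakeArgument/MakeInvoker pattern. Below is a hedged sketch of a generic driver; the concrete DeviceGemm_Xdl_CShuffle_LdsDirectLoad instantiation, device pointers, and problem sizes are assumed to be supplied by the caller rather than taken from this commit, and PassThrough is assumed to be the instance's elementwise operation.

// Illustrative sketch, not part of the commit: generic host-side driver for a CK GEMM instance.
template <typename DeviceOp>
float run_gemm(DeviceOp op, const void* p_a, const void* p_b, void* p_e,
               ck::index_t M, ck::index_t N, ck::index_t K,
               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideE)
{
    using ck::tensor_operation::element_wise::PassThrough;

    auto argument = op.MakeArgument(p_a, p_b, p_e, M, N, K, StrideA, StrideB, StrideE,
                                    PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker  = op.MakeInvoker();

    if(!op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("instance does not support this problem on this device");
    }

    // Time the kernel; StreamConfig{stream, time_kernel}.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}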
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -281,6 +281,24 @@ struct ConvertF8SR
    }
};

struct ConvertF8RNE
{
    // convert to fp8 using rounding to nearest even
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // check Y datatype
        static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
                      "Data type is not supported by this operation!");

        // check X datatype
        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
                      "Data type is not supported by this operation!");

        y = f8_convert_rne<Y>(x);
    }
};

struct Scale
{
    __host__ __device__ Scale(float scale) : scale_(scale) {}
...
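A short usage sketch of the new functor (the input value is illustrative and not taken from this commit):

// Illustrative sketch, not part of the commit: applying ConvertF8RNE directly.
ck::tensor_operation::element_wise::ConvertF8RNE convert_rne{};
ck::f8_t y;
convert_rne(y, 1.5f); // y holds the fp8 encoding of 1.5f, rounded to nearest even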
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (new file; this diff is collapsed on the page and its contents are not shown)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp"

 namespace ck {
...
@@ -14,6 +15,8 @@ enum struct PipelineVersion
{
    v1,
    v2,
    // v3 is only used in the Stream-K implementation.
    v4,
};

template <PipelineVersion PipelineVer,
...
@@ -36,6 +39,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
    {
        return GridwiseGemmPipeline_v2{};
    }
    else if constexpr(PipelineVer == PipelineVersion::v4)
    {
        return GridwiseGemmPipeline_v4<NumPrefetch>{};
    }
    else
    {
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <index_t NumPrefetch>
struct GridwiseGemmPipeline_v4;

// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v4<1>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds_direct_load();

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds_direct_load();

                a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
                b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds_direct_load();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

@@ -996,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.K0 % K0PerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
...
include/ck/utility/amd_buffer_addressing.hpp

@@ -944,4 +944,41 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
#endif
}

// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                index_t size,
                                index_t voffset,
                                index_t soffset,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");

template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                              const index_t global_offset,
                                              T* lds_base_ptr,
                                              const index_t lds_offset,
                                              const bool is_valid,
                                              const index_t src_element_space_size)
{
    // Direct loads require that each thread reads and writes exactly a single DWORD.
    constexpr auto dword_bytes      = 4;
    constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
    static_assert(bytes_per_thread == dword_bytes);

    const uint32_t* global_ptr =
        reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
    const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
    const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

    // LDS pointer must be attributed with the LDS address space.
    __attribute__((address_space(3))) uint32_t* lds_ptr =
        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));

    llvm_amdgcn_raw_buffer_load_lds(
        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
}

} // namespace ck
include/ck/utility/dynamic_buffer.hpp

@@ -173,6 +173,26 @@ struct DynamicBuffer
        }
    }

    template <typename DstBuffer, index_t NumElemsPerThread>
    __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
                                             index_t src_offset,
                                             index_t dst_offset,
                                             bool is_valid_element) const
    {
        // Copy data from global to LDS memory using direct loads.
        static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                                            src_offset,
                                                            dst_buf.p_data_,
                                                            dst_offset,
                                                            is_valid_element,
                                                            element_space_size_);
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
...
include/ck/utility/synchronization.hpp

@@ -19,6 +19,15 @@ __device__ void block_sync_lds()
#endif
}

__device__ void block_sync_lds_direct_load()
{
    asm volatile("\
    s_waitcnt vmcnt(0) \n \
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
}

__device__ void s_nop()
{
#if 1
...
include/ck/utility/type_convert.hpp
View file @
47a2a8e1
...
...
@@ -95,9 +95,113 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert fp32 to fp8
// Declare a template function for fp8 conversion using SR
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// Declare a template function for fp8 conversion using RNE
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_rne
(
X
x
);
// convert fp32 to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
float
>
(
float
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float
max_fp8
=
240.0
f
;
...
...
@@ -124,6 +228,80 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_t
f8_convert_rne
<
f8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_rne
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
float
>
(
float
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_rne
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to fp8
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
float
>
(
float
x
)
{
#if defined CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
#else
return
f8_convert_rne
<
f8_t
>
(
x
);
#endif
}
// convert fp8 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
...
...
@@ -174,17 +352,10 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
template
<
>
inline
__host__
__device__
f8_t
type_convert
<
f8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#if defined CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
f8_t
>
(
x
);
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_convert_nre
<
f8_t
>
(
x
);
#endif
}
...
...
@@ -205,26 +376,10 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
float
>
(
float
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#if defined CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_convert_rne
<
bf8_t
>
(
x
);
#endif
}
...
...
@@ -248,17 +403,10 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
template
<
>
inline
__host__
__device__
bf8_t
type_convert
<
bf8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
type_convert
<
bf8_t
>
(
type_convert
<
float
>
(
x
));
#if defined CK_USE_SR_F8_CONVERSION
return
f8_convert_sr
<
bf8_t
>
(
x
);
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
f8_convert_rne
<
bf8_t
>
(
x
);
#endif
}
...
...
@@ -331,104 +479,4 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
// Declare a template function for fp8 conversion using SR
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
f8_convert_sr
(
X
x
);
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
half_t
>
(
half_t
x
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return
f8_convert_sr
<
f8_t
>
(
type_convert
<
float
>
(
x
));
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
return
utils
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // convert to float and use native conversion (note: converting via bf8_t here,
    // not f8_t, so the bf8 builtin path is taken)
    return f8_convert_sr<bf8_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
    constexpr int seed               = 42;
    // as thread id is not available on host, use 0 for prn generation
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
    return utils::cast_to_f8<half_t,
                             bf8_t,
                             negative_zero_nan,
                             clip,
                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}

} // namespace ck
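
A minimal host-side sketch of how the stochastic-rounding path above could be exercised next to the round-to-nearest path. This is not part of the diff; it assumes the CK headers are on the include path and that f8_convert_rne and type_convert<float> are provided by the same header, as referenced above.

// Hedged sketch: compare stochastic and round-to-nearest-even fp8 conversions on the host.
#include <cstdio>

#include "ck/ck.hpp"
#include "ck/utility/type_convert.hpp"

int main()
{
    const float x = 0.1234f;

    // Stochastic rounding: on non-gfx94x targets the emulation path derives its
    // random bits from the value and its address via prand_generator, as shown above.
    ck::f8_t y_sr = ck::f8_convert_sr<ck::f8_t>(x);

    // Round-to-nearest-even for comparison.
    ck::f8_t y_rne = ck::f8_convert_rne<ck::f8_t>(x);

    std::printf("sr:  %f\n", ck::type_convert<float>(y_sr));
    std::printf("rne: %f\n", ck::type_convert<float>(y_rne));
    return 0;
}
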
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...
...
@@ -227,6 +227,10 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
...
...
@@ -289,6 +293,26 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
...
...
@@ -382,6 +406,8 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
#endif
            add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                          is_same_v<CLayout, Row>)
...
...
@@ -391,6 +417,8 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
#endif
            add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
...
...
@@ -400,6 +428,8 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
#endif
            add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                          is_same_v<CLayout, Row>)
...
...
@@ -409,6 +439,8 @@ struct DeviceOperationInstanceFactory<
            add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
#endif
            add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(op_ptrs);
        }
    }
#ifdef CK_ENABLE_FP16
...
...
@@ -439,6 +471,8 @@ struct DeviceOperationInstanceFactory<
#endif
            add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                          is_same_v<CLayout, Row>)
...
...
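
A minimal client-side sketch of how the newly declared LDS-direct-load FP32 instance list could be populated and inspected. This is not part of the diff; it assumes the program links against the GEMM instance library built below, that the include paths outside this diff are as named here, and that GetTypeString() is available on the DeviceGemm base operator (the usual CK convention).

// Hedged sketch: collect and list the LDS-direct-load FP32 (row-major A, column-major B) instances.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::tensor_operation::device;

    // Container type matches the declaration above: A row-major, B column-major, C row-major.
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, float, float, float, PassThrough, PassThrough, PassThrough>>>
        op_ptrs;

    instance::add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
        op_ptrs);

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
    return 0;
}
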
library/src/tensor_operation_instance/gpu/CMakeLists.txt
...
...
@@ -58,7 +58,12 @@ endfunction(add_instance_library INSTANCE_NAME)
file(GLOB dir_list LIST_DIRECTORIES true *)
set(CK_DEVICE_INSTANCES)
set(CK_DEVICE_OTHER_INSTANCES)
set(CK_DEVICE_GEMM_INSTANCES)
set(CK_DEVICE_CONV_INSTANCES)
set(CK_DEVICE_MHA_INSTANCES)
set(CK_DEVICE_CONTRACTION_INSTANCES)
set(CK_DEVICE_REDUCTION_INSTANCES)
FOREACH(subdir_path ${dir_list})
    set(target_dir)
    IF(IS_DIRECTORY "${subdir_path}")
...
...
@@ -122,7 +127,19 @@ FOREACH(subdir_path ${dir_list})
        if((add_inst EQUAL 1))
            get_filename_component(target_dir ${subdir_path} NAME)
            add_subdirectory(${target_dir})
            list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            if("${cmake_instance}" MATCHES "gemm")
                list(APPEND CK_DEVICE_GEMM_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            elseif("${cmake_instance}" MATCHES "conv")
                list(APPEND CK_DEVICE_CONV_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            elseif("${cmake_instance}" MATCHES "mha")
                list(APPEND CK_DEVICE_MHA_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            elseif("${cmake_instance}" MATCHES "contr")
                list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            elseif("${cmake_instance}" MATCHES "reduce")
                list(APPEND CK_DEVICE_REDUCTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            else()
                list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
            endif()
            message("add_instance_directory ${subdir_path}")
        else()
            message("skip_instance_directory ${subdir_path}")
...
...
@@ -130,18 +147,14 @@ FOREACH(subdir_path ${dir_list})
    ENDIF()
ENDFOREACH()
add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
add_library(composablekernels::device_operations ALIAS device_operations)
set(DEV_OPS_INC_DIRS
    ${PROJECT_SOURCE_DIR}/include/ck/
    ${PROJECT_SOURCE_DIR}/library/include/ck/)
target_compile_features(device_operations PUBLIC)
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(device_operations PUBLIC
if(CK_DEVICE_OTHER_INSTANCES)
    add_library(device_other_operations STATIC ${CK_DEVICE_OTHER_INSTANCES})
    add_library(composablekernels::device_other_operations ALIAS device_other_operations)
    target_compile_features(device_other_operations PUBLIC)
    set_target_properties(device_other_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_other_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description>
...
...
@@ -157,23 +170,115 @@ target_include_directories(device_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/utility>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/quantization>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/softmax>
    )
    rocm_install(TARGETS device_other_operations
        EXPORT device_other_operationsTargets)
    rocm_install(EXPORT device_other_operationsTargets
        FILE composable_kerneldevice_other_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
if(CK_DEVICE_GEMM_INSTANCES)
    add_library(device_gemm_operations STATIC ${CK_DEVICE_GEMM_INSTANCES})
    add_library(composablekernels::device_gemm_operations ALIAS device_gemm_operations)
    target_compile_features(device_gemm_operations PUBLIC)
    set_target_properties(device_gemm_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_gemm_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
    )
    rocm_install(TARGETS device_gemm_operations
        EXPORT device_gemm_operationsTargets)
    rocm_install(EXPORT device_gemm_operationsTargets
        FILE composable_kerneldevice_gemm_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
if(CK_DEVICE_CONV_INSTANCES)
    add_library(device_conv_operations STATIC ${CK_DEVICE_CONV_INSTANCES})
    add_library(composablekernels::device_conv_operations ALIAS device_conv_operations)
    target_compile_features(device_conv_operations PUBLIC)
    set_target_properties(device_conv_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_conv_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd>
    )
    rocm_install(TARGETS device_conv_operations
        EXPORT device_conv_operationsTargets)
    rocm_install(EXPORT device_conv_operationsTargets
        FILE composable_kerneldevice_conv_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
if(CK_DEVICE_MHA_INSTANCES)
    add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
    add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
    target_compile_features(device_mha_operations PUBLIC)
    set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_mha_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/mha>
    )
    rocm_install(TARGETS device_mha_operations
        EXPORT device_mha_operationsTargets)
    rocm_install(EXPORT device_mha_operationsTargets
        FILE composable_kerneldevice_mha_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
if(CK_DEVICE_CONTRACTION_INSTANCES)
    add_library(device_contraction_operations STATIC ${CK_DEVICE_CONTRACTION_INSTANCES})
    add_library(composablekernels::device_contraction_operations ALIAS device_contraction_operations)
    target_compile_features(device_contraction_operations PUBLIC)
    set_target_properties(device_contraction_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_contraction_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/contraction>
    )
    rocm_install(TARGETS device_contraction_operations
        EXPORT device_contraction_operationsTargets)
    rocm_install(EXPORT device_contraction_operationsTargets
        FILE composable_kerneldevice_contraction_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
if(CK_DEVICE_REDUCTION_INSTANCES)
    add_library(device_reduction_operations STATIC ${CK_DEVICE_REDUCTION_INSTANCES})
    add_library(composablekernels::device_reduction_operations ALIAS device_reduction_operations)
    target_compile_features(device_reduction_operations PUBLIC)
    set_target_properties(device_reduction_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_include_directories(device_reduction_operations PUBLIC
        $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
    )
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options(device_operations PRIVATE
    --offload-arch=gfx908
    --offload-arch=gfx90a)
# install(TARGETS device_operations LIBRARY DESTINATION lib)
rocm_install(TARGETS device_operations
    EXPORT device_operationsTargets)
rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
rocm_install(EXPORT device_operationsTargets
    FILE composable_kerneldevice_operationsTargets.cmake)
    rocm_install(TARGETS device_reduction_operations
        EXPORT device_reduction_operationsTargets)
    rocm_install(EXPORT device_reduction_operationsTargets
        FILE composable_kerneldevice_reduction_operationsTargets.cmake
        NAMESPACE composable_kernel::
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
endif()
add_library(device_operations INTERFACE)
target_link_libraries(device_operations INTERFACE
    device_contraction_operations
    device_conv_operations
    device_gemm_operations
    device_other_operations
    device_reduction_operations
    utility)
set(DEV_OPS_INC_DIRS
    ${PROJECT_SOURCE_DIR}/include/ck/
    ${PROJECT_SOURCE_DIR}/library/include/ck/)
rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...
...
@@ -13,6 +13,10 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
...
...
@@ -41,6 +45,7 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    // clang-format off
// ##################################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle_LdsDirectLoad<Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault,   1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
    DeviceGemm_Xdl_CShuffle_LdsDirectLoad<Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>
    // clang-format on
    >;

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
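
A minimal sketch of how the instances registered above could be filtered for a concrete problem size. This is not part of the diff; the MakeArgumentPointer / IsSupportedArgument / GetTypeString calls follow the usual CK DeviceGemm interface and are assumed here, the include paths outside this diff are assumptions, and the null data pointers are only acceptable because no kernel is launched.

// Hedged sketch: check which f16 mk_nk instances above accept a hypothetical problem size.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::tensor_operation::device;

    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>
        op_ptrs;
    instance::add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
        op_ptrs);

    // Hypothetical problem size: C[M, N] = A[M, K] * B[K, N], A/C row-major, B column-major.
    const ck::index_t M = 1024, N = 1024, K = 1024;
    const ck::index_t StrideA = K, StrideB = K, StrideC = N;

    for(const auto& op : op_ptrs)
    {
        // Null data pointers: only shape/stride (and device) checks are performed here.
        auto arg = op->MakeArgumentPointer(nullptr, nullptr, nullptr,
                                           M, N, K, StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{});
        std::cout << op->GetTypeString() << ": "
                  << (op->IsSupportedArgument(arg.get()) ? "supported" : "not supported")
                  << '\n';
    }
    return 0;
}
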
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances = std::tuple<
    // clang-format off
// ##################################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceGemm_Xdl_CShuffle_LdsDirectLoad<Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
    // clang-format on
    >;

void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck