composable_kernel · commit 5f4c1ddb
Authored Nov 01, 2023 by Bartlomiej Wroblewski

Clean the code and comments

parent 0661e8d2

Showing 6 changed files with 80 additions and 152 deletions
example/01_gemm/CMakeLists.txt  (+2 / -4)
example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp  (+0 / -14)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp  (+20 / -39)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (+7 / -14)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp  (+5 / -13)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp  (+46 / -68)
example/01_gemm/CMakeLists.txt

@@ -44,7 +44,7 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
-# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
+# FIXME: re-enable this example as test when SWDEV-335738 is fixed
 add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
...
@@ -58,9 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 if(GPU_TARGETS MATCHES "gfx90a")
     add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
-    endif()
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
 endif()
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
...
example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp

@@ -6,26 +6,12 @@
#include "common.hpp"

#define USING_DIRECT_LOADS 1
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"

using F32 = float;

using ADataType = F32;
...
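The example's preamble shown above reduces to an include plus a set of type aliases. A minimal sketch of that pattern follows; only F32 and ADataType appear in the hunk itself, so the remaining aliases are illustrative assumptions, not the example's exact source:

// Sketch only: the FP32 direct-load example names its operand types via aliases so the
// long device-op template instantiation that follows stays readable.
using F32 = float;

using ADataType = F32;
using BDataType = F32; // assumed: same element type for B in the FP32 example
using CDataType = F32; // assumed: same element type for the output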
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp

@@ -67,17 +67,13 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                      "The number of threads cannot be less than the number of elements in "
                      "thread cluster lengths.");

-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId()));
-            const auto thread_data_idx_begin = thread_cluster_idx;
-
-            SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
-            SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
-        }
+        const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+            make_multi_index(ThreadGroup::GetThreadId()));
+        const auto thread_data_idx_begin = thread_cluster_idx;
+
+        SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
+        SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
@@ -103,11 +99,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
-        if(ThreadGroup::GetNumOfThread() != thread_cluster_desc_.GetElementSize() &&
-           ThreadGroup::GetThreadId() >= thread_cluster_desc_.GetElementSize())
-        {
-            return;
-        }
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
...
@@ -120,21 +111,19 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
            "DstBuffer and DstData data types must be consistent.");

-        constexpr auto dst_access_lengths = thread_slice_lengths;
-        constexpr auto dst_dim_access_order = Sequence<0, 1, 2>{};
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
+        constexpr auto dst_access_lengths = thread_slice_lengths;

        const auto dst_forward_steps  = generate_steps(dst_desc, 1);
        const auto dst_backward_steps = generate_steps(dst_desc, -1);
        const auto src_forward_steps  = generate_steps(src_desc, 1);
        const auto src_backward_steps = generate_steps(src_desc, -1);

-        // loop over tensor and copy
-        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+        // Loop over the destination block and copy data.
+        static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            const auto src_offset = src_coord_.GetOffset();
            const auto dst_offset = dst_coord_.GetOffset();

            // Check if src data is not in the logic padding area.
            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
...
@@ -145,11 +134,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim_(i) &= ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                        move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
                    });
                });
...
@@ -157,7 +145,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
            }();

-            // judge move forward or move backward
+            // Decide whether to move forward or backward.
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;
...
@@ -167,7 +155,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
-                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                        tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
...
@@ -181,33 +169,26 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
                    }
                    else
                    {
-                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
                    }
                }
            });
        });

        // Reset the destination slice since the entire buffer has been already filled.
        ResetDstSliceWindow(dst_desc);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            src_slice_origin_ = src_slice_origin_ + step;
-            src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
-        }
+        src_slice_origin_ = src_slice_origin_ + step;
+        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
    }

    template <typename DescType>
...
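The forward_sweep/move_on_dim logic simplified above implements a boustrophedon ("snake") walk over the destination access grid, so that consecutive accesses differ by a single +1 or -1 step along one dimension and can be applied with the precomputed forward/backward steps. A minimal standalone sketch of the same idea in plain C++ (not the library's template machinery):

#include <array>
#include <cstdio>

// Walk a 2-D access grid of shape lengths[0] x lengths[1] in snake order:
// even rows sweep forward (left-to-right), odd rows sweep backward.
int main()
{
    constexpr std::array<int, 2> lengths{3, 4};

    for(int i0 = 0; i0 < lengths[0]; ++i0)
    {
        const bool forward_sweep = (i0 % 2 == 0);
        for(int j = 0; j < lengths[1]; ++j)
        {
            const int i1 = forward_sweep ? j : lengths[1] - 1 - j;
            // Consecutive (i0, i1) pairs differ by exactly one unit step, which is what
            // lets the transfer advance its tensor coordinates incrementally.
            std::printf("(%d, %d)\n", i0, i1);
        }
    }
    return 0;
}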
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

@@ -191,7 +191,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
        }

-        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
...
@@ -206,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            return false;
        }

-        // check vector load/store
+        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

-            // check vector load of A
+            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
...
@@ -221,7 +220,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
...
@@ -232,7 +230,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }

-            // check vector load of B
+            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
...
@@ -242,7 +240,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
...
@@ -253,8 +250,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }

-            // check vector load of Ds
-            // only support RowMajor for now
+            // Check vector load of Ds.
+            // For now, only the RowMajor layout is supported.
            bool all_valid = true;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
...
@@ -271,8 +268,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                return false;
            }

-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
...
@@ -293,7 +290,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                arg.block_2_etile_map_);
        }

-        // polymorphic
        bool IsSupportedArgument(const BaseArgument* p_arg) override
        {
            return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
...
@@ -332,7 +328,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    static auto MakeInvoker() { return Invoker{}; }

-    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
...
@@ -365,13 +360,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            cde_element_op);
    }

-    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

-    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
...
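The IsSupportedArgument checks annotated above all follow one pattern: the raw length of the tensor's contiguous dimension must be divisible by the configured scalars-per-vector, and which dimension is contiguous depends on the layout. A minimal host-side sketch of that rule, with simplified names rather than the library's API:

#include <cstdint>

// A row-major A tensor of size MRaw x KRaw is readable with vector loads of
// `scalar_per_vector` elements along K, its contiguous dimension.
bool CanVectorLoadRowMajorA(int32_t KRaw, int32_t scalar_per_vector)
{
    return KRaw % scalar_per_vector == 0;
}

// For a column-major A the contiguous dimension is M, so M is checked instead.
// The pre-cleanup code flagged this branch as "FIXME: not rigorous", since padding and
// strides can make a plain divisibility test insufficient.
bool CanVectorLoadColMajorA(int32_t MRaw, int32_t scalar_per_vector)
{
    return MRaw % scalar_per_vector == 0;
}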
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp

@@ -118,7 +118,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
    using Argument = typename GridwiseGemm::Argument;

-    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
...
@@ -186,7 +185,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
        }

-        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
...
@@ -201,12 +199,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            return false;
        }

-        // check vector load/store
+        // Check vector load/store.
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

-            // check vector load of A
+            // Check vector load of A.
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
...
@@ -216,7 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                {
                    return false;
...
@@ -227,7 +224,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                return false;
            }

-            // check vector load of B
+            // Check vector load of B.
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
...
@@ -237,7 +234,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
-                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                {
                    return false;
...
@@ -248,8 +244,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                return false;
            }

-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
...
@@ -270,7 +266,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                arg.block_2_etile_map_);
        }

-        // polymorphic
        bool IsSupportedArgument(const BaseArgument* p_arg) override
        {
            return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
...
@@ -310,7 +305,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
    static auto MakeInvoker() { return Invoker{}; }

-    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
...
@@ -344,13 +338,11 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            cde_element_op);
    }

-    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

-    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
...
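The removed "// polymorphic" comments were annotating methods whose `override` keyword already documents that they implement the device-op's virtual interface, which is presumably why the cleanup drops them. A minimal sketch of that pattern with illustrative names (not the library's exact classes):

#include <memory>
#include <string>

struct BaseInvoker
{
    virtual ~BaseInvoker() = default;
};

struct DeviceOpBase
{
    virtual ~DeviceOpBase()                                   = default;
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
    virtual std::string GetTypeString() const                 = 0;
};

struct MyDeviceOp final : DeviceOpBase
{
    struct Invoker final : BaseInvoker
    {
    };

    // `override` already states these are polymorphic implementations; a comment saying
    // so adds no information.
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override { return std::make_unique<Invoker>(); }
    std::string GetTypeString() const override { return "MyDeviceOp"; }
};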
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

@@ -55,8 +55,8 @@ __global__ void
        e_grid_desc_mblock_mperblock_nblock_nperblock,
    const Block2ETileMap block_2_etile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
+    defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -173,7 +173,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
-        // A matrix in LDS memory, dst of blockwise copy
+        // A matrix in LDS memory, destination of blockwise copy.
        return make_naive_tensor_descriptor(
            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
...
@@ -181,7 +181,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
-        // B matrix in LDS memory, dst of blockwise copy
+        // B matrix in LDS memory, destination of blockwise copy.
        return make_naive_tensor_descriptor(
            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
...
@@ -217,11 +217,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
-        // LDS allocation for A and B: be careful of alignment
+        // LDS allocation for A and B: be careful of alignment.
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...
@@ -230,7 +229,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

-        // LDS allocation for C shuffle in LDS
+        // LDS allocation for C shuffle.
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
...
@@ -316,11 +315,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        const std::array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
-            [&](auto i) {
-                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
-                return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]);
-            },
+            [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
            Number<NumDTensor>{});
    }
...
@@ -329,7 +324,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N(1, 1, 1));

-    // A desc for source in blockwise copy
+    // A desc for source in blockwise copy.
    __host__ __device__ static constexpr auto
    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
...
@@ -345,7 +340,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-    // B desc for source in blockwise copy
+    // B desc for source in blockwise copy.
    __host__ __device__ static constexpr auto
    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
...
@@ -361,7 +356,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-    // E desc for destination in blockwise copy
+    // E desc for destination in blockwise copy.
    __host__ __device__ static constexpr auto
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
    {
...
@@ -381,7 +376,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }

-    // Ds desc for source in blockwise copy
+    // Ds desc for source in blockwise copy.
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
    {
...
@@ -392,7 +387,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            Number<NumDTensor>{});
    }

-    // return block_id to E matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
    {
...
@@ -411,10 +405,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        remove_cvref_t<decltype(
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

-    // block-to-e-tile map
    using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

-    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                                            const BGridDesc_N_K& b_grid_desc_n_k,
                                                            const DsGridDesc_M_N& ds_grid_desc_m_n,
...
@@ -439,7 +431,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        const auto AK = a_grid_desc_m_k.GetLength(I1);
        const auto BK = b_grid_desc_n_k.GetLength(I1);

-        // check consistency of desc
+        // Check the consistency of descriptors.
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
        {
            return false;
...
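GetSharedMemoryNumberOfByte(), touched above, sizes the LDS workspace by rounding each block descriptor's element space up to the common alignment math::lcm(AK1, BK1) before placing the tiles back to back. A simplified host-side sketch of just the A/B part, assuming A and B share one element size; the real function additionally accounts for the C-shuffle workspace, which is omitted here:

#include <cstdint>
#include <numeric>

// Round the A tile up to the LDS alignment (lcm of the K1 vector widths), then append
// the similarly aligned B tile, and convert the element count to bytes.
int64_t ABSharedMemoryBytes(
    int64_t a_block_elems, int64_t b_block_elems, int64_t ak1, int64_t bk1, int64_t elem_bytes)
{
    const int64_t max_lds_align = std::lcm(ak1, bk1);

    const auto integer_least_multiple = [&](int64_t x) {
        return ((x + max_lds_align - 1) / max_lds_align) * max_lds_align;
    };

    return (integer_least_multiple(a_block_elems) + integer_least_multiple(b_block_elems)) *
           elem_bytes;
}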
@@ -457,28 +449,26 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            return false;
        }

-        // check tile size
+        // Check the tile size.
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
        {
            return false;
        }

-        // check gridwise gemm pipeline
+        // Check gridwise gemm pipeline.
        const auto num_k_loop = AK / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

-        // check block-to-E-tile
+        // Check block-to-E-tile.
        if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        // check tensor size: cannot be larger than 2GB each
+        // Check tensor size: cannot exceed 2GB.
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
...
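The 2 GB rule at the end of CheckValidity() scales each grid buffer's element space by the element size and requires the result to fit under 1 << 31 bytes. A minimal sketch of that single check with simplified stand-in names:

#include <cstdint>

using long_index_t = int64_t;

// True if a tensor with `element_space_size` addressable elements of type DataType stays
// within the 2 GB byte range assumed by the kernels.
template <typename DataType>
bool TensorFitsIn2GB(long_index_t element_space_size)
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    return element_space_size * static_cast<long_index_t>(sizeof(DataType)) <= TwoGB;
}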
@@ -522,7 +512,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            e_grid_desc_mblock_mperblock_nblock_nperblock,
        const Block2ETileMap& block_2_etile_map)
    {
-        // elementwise operations are not supported for A and B, left only for the API consistency
+        // Elementwise operations are not supported for A and B, arguments left only for the API
+        // consistency.
        (void)a_element_op;
        (void)b_element_op;
...
@@ -543,7 +534,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
+        // Divide block work by [M, N].
        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...
@@ -555,7 +546,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            return;
        }

-        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+        // This forces m/n_block_data_idx_on_grid into SGPR.
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
...
@@ -564,13 +555,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

-        // A matrix in LDS memory, dst of blockwise copy
+        // A matrix in LDS memory, destination of blockwise copy.
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

-        // B matrix in LDS memory, dst of blockwise copy
+        // B matrix in LDS memory, destination of blockwise copy.
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
                                                      Sequence<AK0PerBlock, MPerBlock, AK1>,
...
@@ -588,7 +578,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0));

-        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
                                                      Sequence<BK0PerBlock, NPerBlock, BK1>,
...
@@ -612,7 +601,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
-        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
...
@@ -634,7 +622,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

-        // LDS allocation for A and B: be careful of alignment
+        // LDS allocation for A and B: be careful of alignment.
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...
@@ -648,7 +636,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

-        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
...
@@ -672,7 +659,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            c_thread_buf,
            num_k_block_main_loop);

-        // shuffle C and write out
+        // Shuffle C and write out.
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
...
@@ -723,8 +710,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

-            // calculate origin of thread output tensor on global memory
-            // blockwise GEMM c matrix starting index
+            // Calculate the origin of thread output tensor on global memory.
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
...
@@ -751,7 +737,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

-            // shuffle: threadwise copy C from VGPR to LDS
+            // Shuffle: threadwise copy C from VGPR to LDS.
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
...
@@ -783,7 +769,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                    n_thread_data_on_block_idx[I2]),
                ck::tensor_operation::element_wise::PassThrough{}};

-            // tuple of reference to C/Ds tensor descriptors
+            // A tuple of reference to C/Ds tensor descriptors.
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                generate_tie(
...
@@ -791,7 +777,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                    Number<NumDTensor>{}));

-            // tuple of reference to C/Ds tensor descriptors
+            // A tuple of reference to C/Ds grid buffers.
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie(
...
@@ -799,7 +785,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                    { return ds_grid_buf[i]; },
                    Number<NumDTensor>{}));

-            // tuple of starting index of C/Ds blockwise copy
+            // A tuple of starting index of C/Ds blockwise copy.
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple(
...
@@ -808,7 +794,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                    },
                    Number<NumDTensor>{}));

-            // blockwise copy C/D/E between LDS and global
+            // Blockwise copy C/D/E between LDS and global.
            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
                ThisThreadBlock,
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
...
@@ -816,8 +802,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CDEElementwiseOperation,
-                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                                            // support arbitray type
+                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
...
@@ -838,7 +823,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                cde_element_op};

-            // space filling curve for threadwise C in VGPR before shuffle
+            // Space filling curve for threadwise C in VGPR before shuffle.
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
...
@@ -851,7 +836,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                            M4,
                                            1>>{};

-            // space filling curve for shuffled blockwise C/D/E
+            // Space filling curve for shuffled blockwise C/D/E.
            constexpr auto sfc_cde_block =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
...
@@ -865,20 +850,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
-                // make sure it's safe to write to LDS
+                // Make sure it's safe to write to LDS.
                block_sync_lds();

-                // each thread write its data from VGPR to LDS
+                // Each thread write its data from VGPR to LDS.
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

-                // make sure it's safe to read from LDS
+                // Make sure it's safe to read from LDS.
                block_sync_lds();

-                // each block copy its data from LDS to global
+                // Each block copy its data from LDS to global.
                cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
                                                  c_ds_buf_refs,
...
@@ -890,13 +875,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                    constexpr auto cde_lds_and_global_step =
                        sfc_cde_block.GetForwardStep(access_id);

-                    // move on Ds
+                    // Move on Ds.
                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                    });

-                    // move on E
+                    // Move on E.
                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0,
...
@@ -942,19 +927,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
              NRaw_{NRaw},
              KRaw_{KRaw}
        {
-            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

-                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

-                // D desc
                ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
            });

-            // populate desc for Ds/E
            if(CheckValidity(a_grid_desc_m_k_,
                             b_grid_desc_n_k_,
                             ds_grid_desc_m_n_,
...
@@ -978,19 +956,19 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
            }
        }

-        // pointers
+        // Pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

-        // tensor descriptors for problem definiton
+        // Tensor descriptors for problem definiton
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

-        // tensor descriptors for block/thread-wise copy
+        // Tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -1000,12 +978,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

-        // element-wise op
+        // element-wise ops
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

-        // for checking vector load/store
+        // For checking vector load/store
        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
...
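The "Divide block work by [M, N]" step above obtains the block's output tile by feeding the 1-D workgroup id through the Block2ETileMap. A minimal, self-contained sketch of what such a map does conceptually; the library's map is more elaborate (and validated via CheckValidity), so the plain row-major assignment here is an assumption for illustration only:

#include <cstdint>
#include <utility>

// Map a 1-D block id onto a 2-D (m-block, n-block) tile coordinate of the output E,
// assuming a simple row-major tile ordering with `n_blocks` tiles per row.
std::pair<int32_t, int32_t> BlockIdToTileIdx(int32_t block_1d_id, int32_t n_blocks)
{
    const int32_t m_block = block_1d_id / n_blocks; // which tile row of E
    const int32_t n_block = block_1d_id % n_blocks; // which tile column of E
    return {m_block, n_block};
}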