gaoqiong / composable_kernel

Commit 5f4c1ddb
Authored Nov 01, 2023 by Bartlomiej Wroblewski
Parent: 0661e8d2

Clean the code and comments

Showing 6 changed files with 80 additions and 152 deletions (+80, -152):
example/01_gemm/CMakeLists.txt (+2, -4)
example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp (+0, -14)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp (+20, -39)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp (+7, -14)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp (+5, -13)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp (+46, -68)
example/01_gemm/CMakeLists.txt

Fixes a typo in a FIXME comment and replaces a leftover if(result EQUAL 0) guard around the direct-load example's dependency registration with a plain add_example_dependencies call:

@@ -44,7 +44,7 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
 
-# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
+# FIXME: re-enable this example as test when SWDEV-335738 is fixed
 add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)

@@ -58,9 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 if(GPU_TARGETS MATCHES "gfx90a")
     add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
-    if(result EQUAL 0)
-        add_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
-    endif()
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
 endif()
 
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp

Drops the extra ck includes, leaving only common.hpp and the USING_DIRECT_LOADS switch between the two device-op headers:

@@ -6,26 +6,12 @@
 #include "common.hpp"
 
 #define USING_DIRECT_LOADS 1
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/utility/data_type.hpp"
-
 #if USING_DIRECT_LOADS
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
 #else
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
 #endif
-
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/literals.hpp"
 
 using F32 = float;
 
 using ADataType = F32;
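The USING_DIRECT_LOADS macro is the example's remaining switch between the direct-load device op and the classic cShuffle one. A minimal, self-contained sketch of that selection pattern; GemmWithDirectLoads and GemmWithoutDirectLoads are hypothetical stand-ins for the two CK device operations, not their real interfaces:

    // Sketch only: the two structs stand in for DeviceGemm_Xdl_CShuffle_LdsDirectLoad
    // and DeviceGemm_Xdl_CShuffle; the #if/#else selection is the point here.
    #include <cstdio>

    #define USING_DIRECT_LOADS 1

    struct GemmWithDirectLoads    { static constexpr const char* name = "lds direct load"; };
    struct GemmWithoutDirectLoads { static constexpr const char* name = "cshuffle"; };

    #if USING_DIRECT_LOADS
    using DeviceGemmInstance = GemmWithDirectLoads;
    #else
    using DeviceGemmInstance = GemmWithoutDirectLoads;
    #endif

    int main() { std::printf("selected: %s\n", DeviceGemmInstance::name); }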
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp

Removes the runtime thread guards (the constructor's static_assert already checks the thread count) and drops the dimension-reorder machinery, which was a no-op since dst_dim_access_order was the identity Sequence<0, 1, 2>:

@@ -67,17 +67,13 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                       "The number of threads cannot be less than the number of elements in "
                       "thread cluster lengths.");
 
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(ThreadGroup::GetThreadId()));
-            const auto thread_data_idx_begin = thread_cluster_idx;
-
-            SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
-            SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
-        }
+        const auto thread_cluster_idx =
+            thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
+        const auto thread_data_idx_begin = thread_cluster_idx;
+
+        SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
+        SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -103,11 +99,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf)
     {
-        if(ThreadGroup::GetNumOfThread() != thread_cluster_desc_.GetElementSize() &&
-           ThreadGroup::GetThreadId() >= thread_cluster_desc_.GetElementSize())
-        {
-            return;
-        }
-
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
                       "Source data must come from a global memory buffer.");
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
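Both removed blocks were runtime guards for thread groups whose size does not match the thread cluster; the static_assert shown in the constructor context rejects an undersized thread group at compile time, which is presumably why the per-call guards could go. A minimal sketch of moving such a precondition to compile time, with a hypothetical class rather than CK's template:

    #include <cstdio>

    // Sketch: a thread-count precondition enforced once, at compile time,
    // instead of by an early-return guard inside every member function.
    template <int NumThreads, int ClusterSize>
    struct TransferSketch
    {
        static_assert(NumThreads >= ClusterSize,
                      "The number of threads cannot be less than the number of elements in "
                      "thread cluster lengths.");

        // No "if(thread_id >= cluster_size) return;" on the hot path.
        static void Run() { std::printf("%d threads run unconditionally\n", NumThreads); }
    };

    int main()
    {
        TransferSketch<256, 256>::Run();
        // TransferSketch<64, 256>::Run(); // would fail to compile
    }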
@@ -120,21 +111,19 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
             "DstBuffer and DstData data types must be consistent.");
 
         constexpr auto dst_access_lengths = thread_slice_lengths;
-        constexpr auto dst_dim_access_order = Sequence<0, 1, 2>{};
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
 
         const auto dst_forward_steps  = generate_steps(dst_desc, 1);
         const auto dst_backward_steps = generate_steps(dst_desc, -1);
         const auto src_forward_steps  = generate_steps(src_desc, 1);
         const auto src_backward_steps = generate_steps(src_desc, -1);
 
-        // loop over tensor and copy
-        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+        // Loop over the destination block and copy data.
+        static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             const auto src_offset = src_coord_.GetOffset();
             const auto dst_offset = dst_coord_.GetOffset();
 
+            // Check if src data is not in the logic padding area.
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
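static_ford unrolls a compile-time nested loop over every multi-index in dst_access_lengths and passes each index to the lambda. A plain run-time analogue of that enumeration, with fixed 3D lengths and ordinary loops instead of CK's Sequence machinery:

    #include <cstdio>

    // Run-time analogue of static_ford<decltype(dst_access_lengths)>: visit
    // every multi-index of a 3D access grid in lexicographic order.
    int main()
    {
        const int lengths[3] = {2, 3, 2}; // stand-in for dst_access_lengths

        for(int i0 = 0; i0 < lengths[0]; ++i0)
            for(int i1 = 0; i1 < lengths[1]; ++i1)
                for(int i2 = 0; i2 < lengths[2]; ++i2)
                    std::printf("ordered_dst_access_idx = (%d, %d, %d)\n", i0, i1, i2);
    }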
@@ -145,11 +134,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 StaticallyIndexedArray<bool, nDim> move_on_dim_;
 
                 static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
 
                     static_for<i + 1, nDim, 1>{}([&](auto j) {
                         move_on_dim_(i) &=
-                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                            ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
                     });
                 });
@@ -157,7 +145,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             }();
 
-            // judge move forward or move backward
+            // Decide whether to move forward or backward.
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep_;
@@ -167,7 +155,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                     index_t tmp = ordered_dst_access_idx[I0];
 
                     static_for<1, i, 1>{}([&](auto j) {
-                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                        tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
                     });
 
                     forward_sweep_(i) = tmp % 2 == 0;
@@ -181,33 +169,26 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                 {
                     if constexpr(forward_sweep[i])
                     {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_forward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
                     }
                     else
                     {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_backward_steps[dst_dim_access_order[i]]);
+                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
+                        move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
                     }
                 }
             });
         });
 
+        // Reset the destination slice since the entire buffer has been already filled.
         ResetDstSliceWindow(dst_desc);
     }
 
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
     {
-        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
-           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
-        {
-            src_slice_origin_ = src_slice_origin_ + step;
-            src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
-        }
+        src_slice_origin_ = src_slice_origin_ + step;
+        src_coord_        = make_tensor_coordinate(src_desc, src_slice_origin_);
     }
 
     template <typename DescType>
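Taken together, move_on_dim_ and forward_sweep implement a snake-like walk over the destination block: only one dimension needs to step between consecutive copies, and a dimension is swept forward or backward according to the parity of the mixed-radix value formed from the slower-moving indices, so consecutive accesses stay one coordinate step apart. A standalone host-side sketch of the parity rule; the forward_sweep_(I0) = true seed for dimension 0 falls outside the hunks shown and is assumed here:

    #include <cstdio>

    // For one ordered access index, decide per dimension whether the next step
    // along it goes forward or backward, mirroring the `tmp % 2 == 0` rule.
    int main()
    {
        const int nDim = 3;
        const int dst_access_lengths[3]     = {2, 2, 2};
        const int ordered_dst_access_idx[3] = {1, 0, 1}; // example position

        bool forward_sweep[3];
        forward_sweep[0] = true; // assumed seed for the outermost dimension

        for(int i = 1; i < nDim; ++i)
        {
            // Mixed-radix value of the indices of all slower-moving dimensions.
            int tmp = ordered_dst_access_idx[0];
            for(int j = 1; j < i; ++j)
                tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];

            forward_sweep[i] = (tmp % 2 == 0); // even prefix: forward; odd: backward
        }

        for(int i = 0; i < nDim; ++i)
            std::printf("dim %d sweeps %s\n", i, forward_sweep[i] ? "forward" : "backward");
    }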
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

Capitalizes and punctuates the validity-check comments, rewrites the RowMajor notes, and removes the stray "// polymorphic" and "// FIXME: not rigorous" markers:

@@ -191,7 +191,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
         }
 
-        // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {

@@ -206,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             return false;
         }
 
-        // check vector load/store
+        // Check vector load/store.
         {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-            // check vector load of A
+            // Check vector load of A.
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)

@@ -221,7 +220,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                 {
                     return false;

@@ -232,7 +230,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }
 
-            // check vector load of B
+            // Check vector load of B.
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)

@@ -242,7 +240,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             }
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                 {
                     return false;

@@ -253,8 +250,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }
 
-            // check vector load of Ds
-            // only support RowMajor for now
+            // Check vector load of Ds.
+            // For now, only the RowMajor layout is supported.
             bool all_valid = true;
 
             static_for<0, NumDTensor, 1>{}([&](auto i) {

@@ -271,8 +268,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                 return false;
             }
 
-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
             if constexpr(is_same_v<ELayout, Row>)
             {
                 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)

@@ -293,7 +290,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                                               arg.block_2_etile_map_);
         }
 
-        // polymorphic
         bool IsSupportedArgument(const BaseArgument* p_arg) override
         {
             return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));

@@ -332,7 +328,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
     static auto MakeInvoker() { return Invoker{}; }
 
-    // polymorphic
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,

@@ -365,13 +360,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
                             cde_element_op);
     }
 
-    // polymorphic
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }
 
-    // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
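Every branch above enforces the same rule: the raw extent of the dimension being vectorized must divide evenly by the scalars-per-vector, otherwise a vector access would run past the edge of the matrix. A host-side sketch of the dispatch for A under that rule; the names are illustrative, not CK's API, and the source itself flagged the column-major branch as "not rigorous":

    #include <cstdio>

    // Sketch of the layout-dependent vector-load legality check for A,
    // mirroring the constexpr dispatch in IsSupportedArgument.
    enum class Layout { RowMajor, ColumnMajor };

    bool IsAVectorLoadSupported(
        Layout a_layout, int vector_dim, int m_raw, int k_raw, int scalar_per_vector)
    {
        // Row-major A vectorizes along K (dim 2); column-major A along M (dim 1).
        if(a_layout == Layout::RowMajor && vector_dim == 2)
            return k_raw % scalar_per_vector == 0;
        if(a_layout == Layout::ColumnMajor && vector_dim == 1)
            return m_raw % scalar_per_vector == 0; // "FIXME: not rigorous" in the source
        return false; // any other layout/dimension combination is rejected
    }

    int main()
    {
        std::printf("%d\n", IsAVectorLoadSupported(Layout::RowMajor, 2, 512, 4096, 4)); // 1
        std::printf("%d\n", IsAVectorLoadSupported(Layout::RowMajor, 2, 512, 4094, 4)); // 0
    }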
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp

Applies the same comment cleanup to the single-D device op:

@@ -118,7 +118,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
     using Argument = typename GridwiseGemm::Argument;
 
-    // Invoker
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})

@@ -186,7 +185,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
         }
 
-        // polymorphic
         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {

@@ -201,12 +199,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             return false;
         }
 
-        // check vector load/store
+        // Check vector load/store.
         {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-            // check vector load of A
+            // Check vector load of A.
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)

@@ -216,7 +214,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
                 {
                     return false;

@@ -227,7 +224,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                 return false;
             }
 
-            // check vector load of B
+            // Check vector load of B.
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
                 if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)

@@ -237,7 +234,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
             }
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
-                // FIXME: not rigorous
                 if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
                 {
                     return false;

@@ -248,8 +244,8 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                 return false;
             }
 
-            // check vector store of E
-            // only support RowMajor for now
+            // Check vector load of E.
+            // For now, only the RowMajor layout is supported.
             if constexpr(is_same_v<ELayout, Row>)
             {
                 if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)

@@ -270,7 +266,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                                               arg.block_2_etile_map_);
         }
 
-        // polymorphic
         bool IsSupportedArgument(const BaseArgument* p_arg) override
         {
             return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));

@@ -310,7 +305,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
     static auto MakeInvoker() { return Invoker{}; }
 
-    // polymorphic
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,

@@ -344,13 +338,11 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
                             cde_element_op);
     }
 
-    // polymorphic
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }
 
-    // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

This diff is collapsed (+46, -68; see the summary above).