gaoqiong / composable_kernel · Commits

Commit 326d6bc6, authored Jul 17, 2023 by Jing Zhang

    move all arguments into device

Parent: 03d3395b

Showing 6 changed files with 315 additions and 66 deletions:
cmake/EnableCompilerWarnings.cmake                                                +1   -1
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp                                 +38  -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp           +87  -60
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                       +2   -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp    +183 -0
script/cmake-ck-dev.sh                                                            +4   -2
cmake/EnableCompilerWarnings.cmake

@@ -66,7 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
         -Wno-reserved-identifier
-        -Werror
+        # -Werror
         -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp

@@ -46,7 +46,7 @@ using AElementOp = PassThrough;
 using BElementOp   = PassThrough;
 using CDEElementOp = PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
     // clang-format off
@@ -59,4 +59,40 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
 
 #include "run_grouped_gemm_example.inc"
 
-int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
+// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    problem_size.group_count = 16;
+
+    problem_size.Ms = {
+        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
+
+    for(int i = 0; i < problem_size.group_count; i++)
+    {
+        problem_size.Ns.push_back(768);
+        problem_size.Ks.push_back(4608);
+
+        problem_size.stride_As.push_back(problem_size.Ks[i]);
+        problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+        problem_size.stride_Cs.push_back(problem_size.Ns[i]);
+    }
+
+    if(argc == 4)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        exit(0);
+    }
+
+    return !run_grouped_gemm(problem_size, config);
+}
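
The switch from GemmSpecialization::Default to MNPadding fits the hard-coded problem sizes: the per-group M values above (167, 183, ...) are generally not multiples of a block tile, so M and N must be rounded up to whole tiles. Below is a minimal stand-alone sketch of that rounding; the 128x128 tile is only illustrative (the real MPerBlock/NPerBlock come from the DeviceGemmInstance template parameters, not from this diff).

#include <cstdio>

int main()
{
    const int MPerBlock = 128, NPerBlock = 128; // illustrative tile sizes, not from the commit
    const int Ms[] = {167, 183, 177, 181, 153, 139, 156, 173,
                      163, 150, 204, 184, 168, 156, 168, 148};
    const int N = 768;

    for(int M : Ms)
    {
        // MNPadding rounds M and N up to multiples of the block tile,
        // so each group's output maps onto a whole number of E-tiles.
        const int MPad = (M + MPerBlock - 1) / MPerBlock * MPerBlock;
        const int NPad = (N + NPerBlock - 1) / NPerBlock * NPerBlock;
        std::printf("M=%3d -> padded %dx%d = %dx%d tiles\n",
                    M, MPad, NPad, MPad / MPerBlock, NPad / NPerBlock);
    }
    return 0;
}

With these illustrative tile sizes every group pads to the same 2x6 tile grid, which is the property the device-side changes below rely on when they give each group an equal share of the launch grid.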
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

@@ -24,6 +24,8 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
+          GemmSpecialization GemmSpec,
+          typename Block2ETileMapKSplit,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
@@ -34,6 +36,7 @@ __global__ void
 #endif
         kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                 const index_t group_count,
+                                const index_t grid_size_grp,
                                 const AElementwiseOperation a_element_op,
                                 const BElementwiseOperation b_element_op,
                                 const CDEElementwiseOperation c_element_op)
@@ -47,6 +50,7 @@ __global__ void
     const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
         cast_pointer_to_generic_address_space(gemm_descs_const));
 
+#if 0
     index_t left     = 0;
     index_t right    = group_count;
     index_t group_id = index_t((left + right) / 2);
@@ -64,7 +68,14 @@ __global__ void
         }
         group_id = index_t((left + right) / 2);
     }
+#endif
+
+    const index_t group_id = block_id / grid_size_grp;
+
+    if(group_id >= group_count)
+        return;
 
+#if 0
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         gemm_desc_ptr[group_id].a_ptr_,
         gemm_desc_ptr[group_id].b_ptr_,
@@ -79,6 +90,56 @@ __global__ void
         gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_ptr[group_id].block_2_etile_map_);
+#else
+    const index_t M = gemm_desc_ptr[group_id].M;
+    const index_t N = gemm_desc_ptr[group_id].N;
+    const index_t K = gemm_desc_ptr[group_id].K;
+
+    if(M == 0 || N == 0 || K == 0)
+        return;
+
+    const index_t StrideA    = K;
+    const index_t StrideB    = K;
+    const index_t StrideDs[] = {};
+    const index_t StrideE    = N;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    using ALayout  = Row;
+    using BLayout  = Col;
+    using DsLayout = ck::Tuple<>;
+    using ELayout  = Row;
+
+    const auto e_grid_desc_m_n =
+        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+    const index_t BlockStart = group_id * grid_size_grp;
+
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+
+    const auto local_b2e_tile_map = Block2ETileMapKSplit{e_grid_desc_m_n};
+
+    const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart);
+
+    GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
+        gemm_desc_ptr[group_id].a_ptr_,
+        gemm_desc_ptr[group_id].b_ptr_,
+        gemm_desc_ptr[group_id].ds_ptr_,
+        gemm_desc_ptr[group_id].e_ptr_,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        c_element_op,
+        M,
+        N,
+        K,
+        StrideA,
+        StrideB,
+        StrideDs,
+        StrideE,
+        block_2_etile_map);
+#endif
 #else
     ignore = gemm_descs_const;
     ignore = group_count;
@@ -281,46 +342,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
 
-    struct GroupedGemmBlock2ETileMap
-    {
-        using Block2ETileMap =
-            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
-
-        GroupedGemmBlock2ETileMap()
-        {
-            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
-            BlockStart_        = -1;
-        }
-
-        GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
-        {
-            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
-            BlockStart_        = BlockStart;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            return block_2_etile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[I0] - BlockStart_));
-        }
-
-        // it's actually E-Tile
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
-        {
-            return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
-        }
-
-        Block2ETileMap block_2_etile_map_;
-        ck::index_t BlockStart_;
-    };
+    using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
 
     struct GemmBiasTransKernelArg
     {
@@ -330,6 +353,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         typename GridwiseGemm::DsGridPointer ds_ptr_;
         EDataType* e_ptr_;
 
+        index_t M, N, K;
+
         // tensor descriptors for problem definition
         AGridDesc_M_K a_grid_desc_m_k_;
         BGridDesc_N_K b_grid_desc_n_k_;
@@ -344,7 +369,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
 
         // block-to-e-tile map
-        GroupedGemmBlock2ETileMap block_2_etile_map_;
+        Block2ETileMap block_2_etile_map_;
         ck::index_t BlockStart_, BlockEnd_;
     };
@@ -374,7 +399,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         gemm_desc_kernel_arg_.reserve(group_count_);
 
         skipped_group_count_ = 0;
-
+        index_t group_id     = 0;
         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
@@ -385,12 +410,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             a_mtx_mraw_kraw_.emplace_back(M, K);
             b_mtx_nraw_kraw_.emplace_back(N, K);
 
-            if(M == 0)
-            {
-                skipped_group_count_++;
-                continue;
-            }
-
             const index_t StrideA = gemm_descs[i].stride_A_;
             const index_t StrideB = gemm_descs[i].stride_B_;
             const index_t StrideC = gemm_descs[i].stride_C_;
@@ -427,24 +446,23 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const auto b_grid_desc_bk0_n_bk1 =
                 GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
 
-            const index_t grid_size_grp = GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
-                                              .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
+            // block-to-e-tile map
+            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
+
+            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
+
+            std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp << std::endl;
 
             const index_t BlockStart = grid_size_;
             const index_t BlockEnd   = grid_size_ + grid_size_grp;
 
             grid_size_ += grid_size_grp;
 
-            // block-to-e-tile map
-            const auto block_2_etile_map =
-                GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
-
             if(GridwiseGemm::CheckValidity(
-                   a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map))
+                   a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, local_b2c_tile_map))
             {
                 // tensor descriptors for block/thread-wise copy
                 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -465,6 +483,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     static_cast<const BDataType*>(p_Bs[i]),
                     p_ds_grid,
                     static_cast<EDataType*>(p_Es[i]),
+                    M,
+                    N,
+                    K,
                     a_grid_desc_m_k,
                     b_grid_desc_n_k,
                     ds_grid_desc_m_n,
@@ -473,16 +494,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     b_grid_desc_bk0_n_bk1,
                     ds_grid_desc_mblock_mperblock_nblock_nperblock,
                     e_grid_desc_mblock_mperblock_nblock_nperblock,
-                    block_2_etile_map,
+                    local_b2c_tile_map,
                     BlockStart,
                     BlockEnd});
             }
+            group_id++;
         }
     }
 
     //  private:
     index_t group_count_;
     index_t skipped_group_count_;
 
     AElementwiseOperation a_element_op_;
     BElementwiseOperation b_element_op_;
@@ -560,11 +582,16 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
             const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
                                                         GemmBiasTransKernelArg,
+                                                        GemmSpec,
+                                                        Block2ETileMap,
                                                         AElementwiseOperation,
                                                         BElementwiseOperation,
                                                         CDEElementwiseOperation,
                                                         has_main_k_block_loop_>;
 
+            const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
+                                          arg.gemm_desc_kernel_arg_[0].BlockStart_;
+
             return launch_and_time_kernel(
                 stream_config,
                 kernel,
@@ -573,6 +600,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                 0,
                 cast_pointer_to_constant_address_space(arg.p_workspace_),
                 arg.gemm_desc_kernel_arg_.size(),
+                grid_size_grp,
                 arg.a_element_op_,
                 arg.b_element_op_,
                 arg.c_element_op_);
@@ -600,8 +628,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
-            arg.skipped_group_count_) != arg.group_count_)
+        if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
        {
            return false;
        }
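
The kernel-side change above replaces the per-block binary search over [BlockStart_, BlockEnd_) ranges with a single division, which works because the launcher now passes one uniform grid_size_grp (taken from group 0) for every group. A minimal host-side sketch of the new mapping follows; the numbers are illustrative stand-ins, not the library's types or values.

#include <cassert>
#include <cstdio>

int main()
{
    const int group_count   = 16; // matches the example above
    const int grid_size_grp = 12; // tiles per group; illustrative (2 x 6 E-tiles)
    const int grid_size     = group_count * grid_size_grp;

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        // One integer division replaces the old per-block binary search
        // over [BlockStart_, BlockEnd_) ranges.
        const int group_id = block_id / grid_size_grp;
        assert(group_id < group_count);

        // The OffsettedBlockToCTileMap later subtracts the group's
        // BlockStart (= group_id * grid_size_grp) to get a local tile id.
        const int local_block = block_id - group_id * grid_size_grp;

        if(local_block == 0)
            std::printf("group %2d starts at global block %3d\n", group_id, block_id);
    }
    return 0;
}

The trade-off shows up in IsSupportedArgument: skipped (M == 0) groups are no longer tolerated, since every group must now occupy exactly grid_size_grp consecutive blocks.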
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -585,7 +585,8 @@ struct OffsettedBlockToCTileMap
 {
     using underlying_type = UnderlyingBlockToCTileMap;
 
-    OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                 index_t block_start)
     {
         block_to_ctile_map_ = block_to_ctile_map;
         block_start_        = block_start;
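
The only change here is marking the constructor __host__ __device__, so kernel_grouped_gemm_xdl can now construct the map on the device from the per-group BlockStart. For reference, here is a minimal sketch of the offset-then-delegate behavior this wrapper provides, modeled on the GroupedGemmBlock2ETileMap code removed earlier in this commit; UnderlyingMap and its row-major tile order are illustrative stand-ins, not the library's types.

#include <cstdio>

struct UnderlyingMap
{
    int n_tiles_per_row; // tiles along N; illustrative

    // Map a group-local 1-D block id to an (m-tile, n-tile) coordinate.
    void CalculateBottomIndex(int block_id, int& m_tile, int& n_tile) const
    {
        m_tile = block_id / n_tiles_per_row;
        n_tile = block_id % n_tiles_per_row;
    }
};

struct OffsettedMap
{
    UnderlyingMap map_;
    int block_start_;

    // In the real header this constructor is what gained __host__ __device__,
    // allowing device-side construction from per-group arguments.
    OffsettedMap(UnderlyingMap map, int block_start) : map_(map), block_start_(block_start) {}

    void CalculateBottomIndex(int block_id, int& m_tile, int& n_tile) const
    {
        // Subtract the group's starting block id, then delegate.
        map_.CalculateBottomIndex(block_id - block_start_, m_tile, n_tile);
    }
};

int main()
{
    OffsettedMap m{UnderlyingMap{6}, 24}; // this group's tiles start at global block 24
    int mt, nt;
    m.CalculateBottomIndex(31, mt, nt); // global block 31 -> local block 7
    std::printf("tile (%d, %d)\n", mt, nt); // prints: tile (1, 1)
    return 0;
}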
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -15,6 +15,9 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 
 namespace ck {
 
 // GEMM:
@@ -331,6 +334,97 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     using DsGridPointer = decltype(MakeDsGridPointer());
 
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
+
+    template <typename ALayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+    }
+
+    template <typename BLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+    }
+
+    template <typename ELayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto e_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
+            }
+        }();
+
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+    }
+
+    template <typename DsLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
+                             const std::array<index_t, NumDTensor>& NRaws,
+                             const std::array<index_t, NumDTensor>& DsStride)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
+            },
+            Number<NumDTensor>{});
+    }
+
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
@@ -760,6 +854,95 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             });
         }
     }
 
+    template <bool HasMainKBlockLoop,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+#if 0
+              typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+#endif
+              typename Block2ETileMap>
+    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
+                               const ABDataType* __restrict__ p_b_grid,
+                               DsGridPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid,
+                               void* __restrict__ p_shared,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const index_t StrideDs[],
+                               const index_t StrideE,
+#if 0
+                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+#endif
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        // tensor descriptors for problem definition
+        const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
+        const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+        const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop>(p_a_grid,
+                               p_b_grid,
+                               p_ds_grid,
+                               p_e_grid,
+                               p_shared,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op,
+                               a_grid_desc_ak0_m_ak1,
+                               b_grid_desc_bk0_n_bk1,
+                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               block_2_etile_map);
+    }
 };
 
 } // namespace ck
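
The new Make{A,B,E}GridDescriptor helpers all apply one stride rule before padding: the dimension that is contiguous in memory gets stride 1 and the other gets the user-supplied stride. B's descriptor is (N, K) while its layout tag describes the stored matrix, so its pair is reversed relative to A's, as the hunk above shows. A minimal sketch of the rule outside the CK descriptor machinery, using the example's N = 768 and K = 4608 only for illustration:

#include <cstdio>
#include <utility>

enum class Layout { RowMajor, ColumnMajor };

// Stride pair (dim0, dim1) for a dim0 x dim1 descriptor whose storage
// order matches the given layout of those same two dimensions.
std::pair<int, int> naive_strides(Layout layout, int stride)
{
    return layout == Layout::RowMajor ? std::pair<int, int>{stride, 1}
                                      : std::pair<int, int>{1, stride};
}

int main()
{
    // A is an (M, K) descriptor; row-major A with StrideA = K = 4608:
    auto [a0, a1] = naive_strides(Layout::RowMajor, 4608);
    std::printf("A (M,K) strides: (%d, %d)\n", a0, a1); // (4608, 1)

    // E is an (M, N) descriptor; row-major E with StrideE = N = 768:
    auto [e0, e1] = naive_strides(Layout::RowMajor, 768);
    std::printf("E (M,N) strides: (%d, %d)\n", e0, e1); // (768, 1)

    // B's descriptor dims are (N, K) but its tag describes the stored matrix,
    // so the tag is effectively flipped: a column-major B behaves like a
    // row-major (N, K) descriptor, matching make_tuple(StrideB, I1) above.
    auto [b0, b1] = naive_strides(Layout::RowMajor, 4608); // ColumnMajor B tag
    std::printf("B (N,K) strides: (%d, %d)\n", b0, b1); // (4608, 1)
    return 0;
}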
script/cmake-ck-dev.sh

@@ -11,9 +11,11 @@ cmake
 cmake                                                                                             \
 -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
 -save-temps=$PWD"                                                                                 \
 -D CMAKE_BUILD_TYPE=Release                                                                       \
--D BUILD_DEV=ON                                                                                   \
--D GPU_TARGETS="gfx908;gfx90a;gfx940"                                                             \
+-D BUILD_DEV=OFF                                                                                  \
+-D GPU_TARGETS="gfx90a"                                                                           \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON                                                                 \
 -D USE_BITINT_EXTENSION_INT4=OFF                                                                  \
 ${MY_PROJECT_SOURCE}
+#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \