Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
91075f0f
Commit
91075f0f
authored
Jul 26, 2023
by
Jing Zhang
Browse files
clean deviceop
parent
c0264b8f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
82 deletions
+31
-82
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+29
-82
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+2
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
91075f0f
...
@@ -317,6 +317,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -317,6 +317,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
>
;
#if 0
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
...
@@ -325,6 +326,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -325,6 +326,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
#endif
template
<
typename
UnderlyingBlockToCTileMap
>
template
<
typename
UnderlyingBlockToCTileMap
>
struct
OffsettedBlockToCTileMapMLoops
struct
OffsettedBlockToCTileMapMLoops
...
@@ -483,6 +485,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -483,6 +485,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
;
index_t
StrideE_
;
index_t
StrideE_
;
#if 0
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
BGridDesc_N_K b_grid_desc_n_k_;
...
@@ -498,6 +501,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -498,6 +501,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
Block2ETileMap block_2_etile_map_;
#endif
};
};
// Argument
// Argument
...
@@ -591,12 +595,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -591,12 +595,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
M
,
N
,
gemm_descs
[
i
].
stride_Ds_
[
j
]);
M
,
N
,
gemm_descs
[
i
].
stride_Ds_
[
j
]);
});
});
#if 0
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 =
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 =
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
#endif
const
auto
e_grid_desc_m_n
=
const
auto
e_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
M
,
N
,
StrideC
);
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
M
,
N
,
StrideC
);
...
@@ -604,7 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -604,7 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
// block-to-e-tile map
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
};
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
};
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
if
(
group_id
*
grid_size_grp
!=
grid_size_
)
if
(
group_id
*
grid_size_grp
!=
grid_size_
)
{
{
...
@@ -619,41 +625,24 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -619,41 +625,24 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
e_grid_desc_m_n
,
e_grid_desc_m_n
,
local_b2c_tile_map
))
local_b2c_tile_map
))
{
{
// tensor descriptors for block/thread-wise copy
gemm_desc_kernel_arg_
.
push_back
(
GemmBiasTransKernelArg
{
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
p_As
.
size
()
==
0
?
nullptr
:
p_As
[
i
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
p_Bs
.
size
()
==
0
?
nullptr
:
p_Bs
[
i
],
p_ds_grid
,
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
p_Es
[
i
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
(
j
)
=
M
,
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
N
,
ds_grid_desc_m_n
[
j
]);
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
});
});
}
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
else
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
{
e_grid_desc_m_n
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
gemm_desc_kernel_arg_
.
push_back
(
GemmBiasTransKernelArg
{
p_As
.
size
()
==
0
?
nullptr
:
p_As
[
i
],
p_Bs
.
size
()
==
0
?
nullptr
:
p_Bs
[
i
],
p_ds_grid
,
p_Es
[
i
],
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
local_b2c_tile_map
});
}
}
group_id
++
;
group_id
++
;
...
@@ -674,6 +663,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -674,6 +663,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const
void
*
grouped_gemm_kernel_args_dev
;
const
void
*
grouped_gemm_kernel_args_dev
;
index_t
grid_size_
;
index_t
grid_size_
;
index_t
grid_size_grp
;
};
};
// Invoker
// Invoker
...
@@ -691,51 +681,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -691,51 +681,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
#if DEBUG_LOG
const
auto
KPad
=
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_ak0_m_ak1_{"
GridwiseGemm
::
CalculateKPadded
(
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
1
);
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.b_grid_desc_bk0_n_bk1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
KPad
)
!=
has_main_k_block_loop
)
std
::
cout
<<
", arg.d"
<<
i
<<
"_grid_desc_m_n_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_grid_desc_m_n_
[
j
].
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_grid_desc_m_n_
[
j
].
GetLength
(
I1
)
<<
"}"
;
});
std
::
cout
<<
", arg.e_grid_desc_m_n_{ "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_m_k_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_grid_desc_n_k_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_grid_desc_m_n_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
e_grid_desc_m_n_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
auto
K
=
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
)
!=
has_main_k_block_loop
)
{
{
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
}
}
...
@@ -773,8 +722,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -773,8 +722,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
>
;
const
index_t
grid_size_grp
=
arg
.
grid_size_
/
arg
.
group_count_
;
const
void
*
kernel_args_dev
=
nullptr
;
const
void
*
kernel_args_dev
=
nullptr
;
if
(
arg
.
grouped_gemm_kernel_args_dev
!=
nullptr
)
if
(
arg
.
grouped_gemm_kernel_args_dev
!=
nullptr
)
...
@@ -817,7 +764,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -817,7 +764,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
0
,
0
,
cast_pointer_to_constant_address_space
(
kernel_args_dev
),
cast_pointer_to_constant_address_space
(
kernel_args_dev
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
grid_size_grp
,
arg
.
grid_size_grp
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
91075f0f
...
@@ -200,6 +200,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -200,6 +200,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_block_size
*
sizeof
(
CShuffleDataType
));
c_block_size
*
sizeof
(
CShuffleDataType
));
}
}
#if 0
// A desc for source in blockwise copy
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
__host__ __device__ static constexpr auto
...
@@ -233,6 +234,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -233,6 +234,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
#endif
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment