gaoqiong / composable_kernel / Commits / e542dfc4

Commit e542dfc4, authored Jun 16, 2023 by Jing Zhang, committed by root on Jun 16, 2023

test

parent 8bfacf9f
Changes: 1 changed file with 162 additions and 42 deletions

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp (+162, -42)
@@ -23,7 +23,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-#if 1
+#if 0
 template <typename GridwiseGemm,
           typename GemmDesc,
           typename GemmSharedArgs,
@@ -114,7 +114,78 @@ __global__ void
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
+#elif 1
+template <typename GridwiseGemm,
+          typename GemmDesc,
+          bool HasMainKBlockLoop,
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                       const index_t group_size)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
+    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
+    __shared__ uint8_t p_shared[shared_size];
+
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
+        cast_pointer_to_generic_address_space(gemm_descs_const));
+
+    for(index_t group_id = 0; group_id < group_size; group_id++)
+    {
+        const auto M       = gemm_desc_ptr[group_id].karg_.M;
+        const auto N       = gemm_desc_ptr[group_id].karg_.N;
+        const auto K       = gemm_desc_ptr[group_id].karg_.K;
+        const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+        const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
+        const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+        const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+        const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
+        const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
+        const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
+        const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+
+        static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+        static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+        static constexpr index_t B2E_M01   = 8;
+
+        using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+        using Block2ETileMapKSplit =
+            BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+        using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+
+        const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+        const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+        const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
+
+        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
+            gemm_desc_ptr[group_id].karg_.p_a_grid,
+            gemm_desc_ptr[group_id].karg_.p_b_grid,
+            gemm_desc_ptr[group_id].karg_.p_c_grid,
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            StrideC,
+            MPadded,
+            NPadded,
+            KPadded,
+            K0,
+            k_batch,
+            static_cast<void*>(p_shared),
+            block_2_ctile_map);
+    }
 #else
+    ignore = gemm_descs_const;
+    ignore = group_count;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+#elif 0
 template <typename GridwiseGemm,
           typename GemmDesc,
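Note (illustration only, not part of the commit): the kernel added above makes every workgroup sweep all group descriptors and rebuild the block-to-C-tile map on the device with a block offset of 0, instead of reading a map precomputed on the host. The standalone C++ sketch below mirrors that indexing idea with made-up names; the split-K dimension and the real tile-map types are left out.

// Standalone illustration of the per-group indexing scheme; all names are invented.
#include <cstddef>
#include <cstdio>
#include <vector>

struct GroupDesc
{
    int m_tiles; // number of C tiles along M for this group
    int n_tiles; // number of C tiles along N for this group
};

int main()
{
    const std::vector<GroupDesc> groups = {{2, 3}, {4, 1}, {1, 2}};
    const int pool_size                 = 6; // plays the role of the launched grid size

    for(int block_id = 0; block_id < pool_size; ++block_id)
    {
        // Every "block" visits every group, mirroring the for(group_id ...) loop
        // in kernel_grouped_gemm_xdl_splitk.
        for(std::size_t g = 0; g < groups.size(); ++g)
        {
            const int tiles = groups[g].m_tiles * groups[g].n_tiles;
            if(block_id >= tiles)
                continue; // this block owns no tile in this group

            // Local block-to-tile map: offset 0, row-major over (m, n) tiles.
            const int m_tile = block_id / groups[g].n_tiles;
            const int n_tile = block_id % groups[g].n_tiles;
            std::printf("block %d -> group %zu, tile (%d, %d)\n", block_id, g, m_tile, n_tile);
        }
    }
    return 0;
}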
@@ -125,7 +196,15 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       // const index_t group_count
+#if 0
+                                       const index_t N,
+                                       const index_t K,
+                                       const index_t StrideB,
+                                       const index_t NPadded,
+                                       const index_t KPadded,
+                                       const index_t K0,
+                                       const index_t k_batch,
+#endif
                                        const index_t block_size)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
@@ -159,27 +238,55 @@ __global__ void
     const index_t group_id = block_id / block_size;
 #endif
-#if 0
-    const auto N = gemm_desc_ptr[0].karg_.N;
-    const auto K = gemm_desc_ptr[0].karg_.K;
-#if 1
-    const auto StrideB = gemm_desc_ptr[0].karg_.StrideB;
-    const auto NPadded = gemm_desc_ptr[0].karg_.NPadded;
-    const auto KPadded = gemm_desc_ptr[0].karg_.KPadded;
-    const auto K0      = gemm_desc_ptr[0].karg_.KPadded;
-    const auto k_batch = gemm_desc_ptr[0].karg_.k_batch;
-    const auto block_2_ctile_map = gemm_desc_ptr[0].block_2_ctile_map_;
+#if 1
+    const auto M       = gemm_desc_ptr[group_id].karg_.M;
+    const auto N       = gemm_desc_ptr[group_id].karg_.N;
+    const auto K       = gemm_desc_ptr[group_id].karg_.K;
+    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+    const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
+    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+    const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
+    const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
+    const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
+    const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+#else
+    const auto M       = gemm_desc_ptr[group_id].karg_.M;
+    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+#endif
+    // const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
+
+    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+    static constexpr index_t B2E_M01   = 8;
+    const index_t block_start = block_size * group_id;
+
+    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+    using Block2ETileMapKSplit =
+        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+
+    const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+    const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+    const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
         gemm_desc_ptr[group_id].karg_.p_a_grid,
         gemm_desc_ptr[group_id].karg_.p_b_grid,
         gemm_desc_ptr[group_id].karg_.p_c_grid,
-        gemm_desc_ptr[group_id].karg_.M,
+        M,
         N,
         K,
-        gemm_desc_ptr[group_id].karg_.StrideA,
+        StrideA,
         StrideB,
-        gemm_desc_ptr[group_id].karg_.StrideC,
-        gemm_desc_ptr[group_id].karg_.MPadded,
+        StrideC,
+        MPadded,
         NPadded,
         KPadded,
         K0,
...
@@ -325,16 +432,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     struct GemmTransKernelArg
     {
         KernelArgument karg_;
-        GroupedGemmBlock2ETileMap block_2_ctile_map_;
+        // GroupedGemmBlock2ETileMap block_2_ctile_map_;
         index_t block_start_, block_end_;

         GemmTransKernelArg() = default;
         GemmTransKernelArg(KernelArgument&& karg,
-                           GroupedGemmBlock2ETileMap&& b2c_map,
+                           // GroupedGemmBlock2ETileMap&& b2c_map,
                            index_t block_start,
                            index_t block_end)
             : karg_{karg},
-              block_2_ctile_map_{b2c_map},
+              // block_2_ctile_map_{b2c_map},
               block_start_{block_start},
               block_end_{block_end}
         {
@@ -404,10 +511,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-            const index_t block_start = grid_size_;
-            const index_t block_end   = grid_size_ + grid_size_grp;
+            const index_t block_start = 0;
+            const index_t block_end   = grid_size_grp;

-            grid_size_ += grid_size_grp;
+            grid_size_ = grid_size_grp;

             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
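Note (illustration only): the host-side change above replaces cumulative block ranges, where each group owned a disjoint [block_start, block_end) slice of one accumulated grid, with ranges that restart at 0 for every group, so grid_size_ ends up holding only the last group's grid size rather than the sum. The small sketch below, with made-up per-group grid sizes, prints both schemes side by side.

// Host-side illustration of the two block-range schemes; sizes are invented.
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<int> grid_size_grp = {12, 40, 8}; // per-group grid sizes

    int cumulative_grid = 0;
    int per_group_grid  = 0;

    for(std::size_t i = 0; i < grid_size_grp.size(); ++i)
    {
        // Old scheme: block_start = grid_size_; block_end = grid_size_ + grid_size_grp.
        const int old_start = cumulative_grid;
        const int old_end   = cumulative_grid + grid_size_grp[i];
        cumulative_grid += grid_size_grp[i];

        // New scheme: block_start = 0; block_end = grid_size_grp; grid_size_ is overwritten.
        const int new_start = 0;
        const int new_end   = grid_size_grp[i];
        per_group_grid      = grid_size_grp[i];

        std::printf("group %zu: old [%d, %d)  new [%d, %d)\n", i, old_start, old_end, new_start, new_end);
    }
    std::printf("old total grid: %d, new grid (last group only): %d\n", cumulative_grid, per_group_grid);
    return 0;
}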
@@ -428,8 +535,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                               k0,
                                               K_BATCH};

             gemm_kernel_args_.emplace_back(
                 std::move(karg),
-                std::move(grouped_block_2_ctile_map),
+                // std::move(grouped_block_2_ctile_map),
                 block_start,
                 block_end);
         }
     }
@@ -458,21 +567,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

-            const index_t block_start = grid_size_;
-            const index_t block_end   = grid_size_ + grid_size_grp;
+            const index_t block_start = 0;
+            const index_t block_end   = grid_size_grp;

-            grid_size_ += grid_size_grp;
+            grid_size_ = grid_size_grp;

             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

             karg.KPadded = k_padded;
             karg.K0      = k0;
             karg.k_batch = K_BATCH;

-            gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
+            // gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
             gemm_kernel_args_[i].block_start_ = block_start;
             gemm_kernel_args_[i].block_end_   = block_end;
         }
     }
@@ -556,24 +665,24 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
 #if 1
         std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
-        index_t all_gemm_block_size =
-            arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
+        // index_t all_gemm_block_size =
+        // arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;

         for(const auto& trans_arg : arg.gemm_kernel_args_)
         {
             auto karg = ArgumentMsN1K1{
                 trans_arg.karg_.p_a_grid,
                 trans_arg.karg_.p_b_grid,
                 trans_arg.karg_.p_c_grid};
-            // auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
-            // std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
-            //           << " trans_arg.block_end_: " << trans_arg.block_end_
-            //           << " block_size: " << block_size << std::endl;
+            auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
+            std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
+                      << " trans_arg.block_end_: " << trans_arg.block_end_
+                      << " block_size: " << block_size << std::endl;

             gemm_kernel_args_msn1k1_.push_back({karg});
         }
 #endif

-#if 0
+#if 1
         hip_check_error(hipMemcpy(arg.p_workspace_,
                                   arg.gemm_kernel_args_.data(),
                                   arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
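Note (generic HIP sketch, not the commit's call site): the hipMemcpy enabled above stages the host-side array of GemmTransKernelArg into the device workspace that the kernel later reads through gemm_descs_const. The self-contained example below shows the same staging pattern with a placeholder struct; only the standard HIP runtime calls hipMalloc, hipMemcpy, and hipFree are assumed.

// Staging an array of kernel-argument structs into a device workspace (placeholder types).
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

struct KernelArgStub
{
    int M, N, K;
};

int main()
{
    std::vector<KernelArgStub> host_args = {{128, 256, 64}, {64, 64, 32}};

    void* dev_workspace = nullptr;
    const size_t bytes  = host_args.size() * sizeof(KernelArgStub);

    if(hipMalloc(&dev_workspace, bytes) != hipSuccess)
        return 1;

    // Mirrors: hip_check_error(hipMemcpy(arg.p_workspace_, arg.gemm_kernel_args_.data(),
    //          arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), hipMemcpyHostToDevice));
    if(hipMemcpy(dev_workspace, host_args.data(), bytes, hipMemcpyHostToDevice) != hipSuccess)
        return 1;

    std::printf("copied %zu bytes of kernel arguments to the device workspace\n", bytes);
    (void)hipFree(dev_workspace);
    return 0;
}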
@@ -653,14 +762,25 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                     dim3(BlockSize),
                                     0,
                                     cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                    shared_karg
+#if 0
+                                    // all_gemm_block_size
+                                    arg.gemm_kernel_args_[0].karg_.N,
+                                    arg.gemm_kernel_args_[0].karg_.K,
+                                    arg.gemm_kernel_args_[0].karg_.StrideB,
+                                    arg.gemm_kernel_args_[0].karg_.NPadded,
+                                    arg.gemm_kernel_args_[0].karg_.KPadded,
+                                    arg.gemm_kernel_args_[0].karg_.K0,
+                                    arg.gemm_kernel_args_[0].karg_.k_batch,
+#elif 0
+                                    all_gemm_block_size
+#elif 1
+                                    arg.gemm_kernel_args_.size()
+#endif
                                     );
             };

             std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
                       << " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;

-#if 0
+#if 1
             if(all_have_main_k0_block_loop)
             {
                 if(all_have_kbatch_gt_one)
...
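Note (hedged sketch, not the commit's launch path): after this change the launch passes the constant-address-space workspace pointer plus arg.gemm_kernel_args_.size() as the group count, matching the new kernel_grouped_gemm_xdl_splitk signature. The dummy example below shows a launch of that shape with hipLaunchKernelGGL; the kernel, struct, and grid/block sizes are placeholders.

// Dummy launch shaped like (descriptor pointer, group count); all names are placeholders.
#include <hip/hip_runtime.h>
#include <cstdio>

struct KernelArgStub
{
    int M, N, K;
};

__global__ void dummy_grouped_kernel(const KernelArgStub* descs, int group_count)
{
    if(blockIdx.x == 0 && threadIdx.x == 0)
    {
        for(int g = 0; g < group_count; ++g)
            printf("group %d: M=%d N=%d K=%d\n", g, descs[g].M, descs[g].N, descs[g].K);
    }
}

int main()
{
    KernelArgStub host_args[2] = {{128, 256, 64}, {64, 64, 32}};
    KernelArgStub* dev_args    = nullptr;
    (void)hipMalloc(reinterpret_cast<void**>(&dev_args), sizeof(host_args));
    (void)hipMemcpy(dev_args, host_args, sizeof(host_args), hipMemcpyHostToDevice);

    // Grid/block sizes are placeholders; the real code derives them from the tile maps.
    hipLaunchKernelGGL(dummy_grouped_kernel, dim3(1), dim3(64), 0, 0, dev_args, 2);
    (void)hipDeviceSynchronize();
    (void)hipFree(dev_args);
    return 0;
}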