Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8bfacf9f
Commit
8bfacf9f
authored
Jun 13, 2023
by
Jing Zhang
Committed by
root
Jun 16, 2023
Browse files
move block2tile into kernel
parent
cf9bcb31
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
11 deletions
+26
-11
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+20
-10
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+2
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+4
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
8bfacf9f
...
@@ -60,9 +60,6 @@ __global__ void
...
@@ -60,9 +60,6 @@ __global__ void
// const auto K0 = gemm_shared_args.KPadded;
// const auto K0 = gemm_shared_args.KPadded;
// const auto k_batch = gemm_shared_args.k_batch;
// const auto k_batch = gemm_shared_args.k_batch;
// M = 2 N = 768 K = 4608 StrideA = 4608 StrideB = 4608 StrideC = 768 MPadded = 32 NPadded = 768
// KPadded = 4608 K0 = 576 k_batch = 1
const
auto
M
=
2
;
const
auto
M
=
2
;
const
auto
N
=
768
;
const
auto
N
=
768
;
const
auto
K
=
4608
;
const
auto
K
=
4608
;
...
@@ -75,7 +72,22 @@ __global__ void
...
@@ -75,7 +72,22 @@ __global__ void
const
auto
K0
=
576
;
const
auto
K0
=
576
;
const
auto
k_batch
=
1
;
const
auto
k_batch
=
1
;
// const auto block_2_ctile_map = gemm_shared_args.block_2_ctile_map;
static
constexpr
index_t
MPerBlock
=
GridwiseGemm
::
GetMPerBlock
();
static
constexpr
index_t
NPerBlock
=
GridwiseGemm
::
GetNPerBlock
();
static
constexpr
index_t
B2E_M01
=
8
;
const
index_t
block_start
=
gemm_shared_args
.
block_size
*
group_id
;
using
CGridDesc_M_N
=
typename
GridwiseGemm
::
CGridDesc_M_N
;
using
Block2ETileMapKSplit
=
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMap
<
Block2ETileMapKSplit
>
;
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
const
auto
local_b2c_tile_map
=
Block2ETileMapKSplit
{
c_grid_desc_m_n
,
B2E_M01
,
k_batch
};
auto
grouped_block_2_ctile_map
=
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
const
auto
block_2_ctile_map
=
grouped_block_2_ctile_map
;
#endif
#endif
...
@@ -95,7 +107,7 @@ __global__ void
...
@@ -95,7 +107,7 @@ __global__ void
K0
,
K0
,
k_batch
,
k_batch
,
static_cast
<
void
*>
(
p_shared
),
static_cast
<
void
*>
(
p_shared
),
gemm_desc_ptr
[
group_id
].
karg_
.
block_2_ctile_map
);
block_2_ctile_map
);
#else
#else
ignore
=
gemm_descs_const
;
ignore
=
gemm_descs_const
;
ignore
=
all_gemm_block_size
;
ignore
=
all_gemm_block_size
;
...
@@ -533,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -533,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// index_t StrideA;
// index_t StrideA;
// index_t StrideC;
// index_t StrideC;
// index_t MPadded;
// index_t MPadded;
GroupedGemmBlock2ETileMap
block_2_ctile_map
;
//
GroupedGemmBlock2ETileMap block_2_ctile_map;
};
};
struct
GemmTransKernelArgMsN1K1
struct
GemmTransKernelArgMsN1K1
...
@@ -549,10 +561,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -549,10 +561,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
for
(
const
auto
&
trans_arg
:
arg
.
gemm_kernel_args_
)
for
(
const
auto
&
trans_arg
:
arg
.
gemm_kernel_args_
)
{
{
auto
karg
=
ArgumentMsN1K1
{
trans_arg
.
karg_
.
p_a_grid
,
auto
karg
=
ArgumentMsN1K1
{
trans_arg
.
karg_
.
p_b_grid
,
trans_arg
.
karg_
.
p_a_grid
,
trans_arg
.
karg_
.
p_b_grid
,
trans_arg
.
karg_
.
p_c_grid
};
trans_arg
.
karg_
.
p_c_grid
,
trans_arg
.
block_2_ctile_map_
};
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
8bfacf9f
...
@@ -551,7 +551,8 @@ struct OffsettedBlockToCTileMap
...
@@ -551,7 +551,8 @@ struct OffsettedBlockToCTileMap
{
{
using
underlying_type
=
UnderlyingBlockToCTileMap
;
using
underlying_type
=
UnderlyingBlockToCTileMap
;
OffsettedBlockToCTileMap
(
UnderlyingBlockToCTileMap
block_to_ctile_map
,
index_t
block_start
)
__host__
__device__
OffsettedBlockToCTileMap
(
UnderlyingBlockToCTileMap
block_to_ctile_map
,
index_t
block_start
)
{
{
block_to_ctile_map_
=
block_to_ctile_map
;
block_to_ctile_map_
=
block_to_ctile_map
;
block_start_
=
block_start
;
block_start_
=
block_start
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
8bfacf9f
...
@@ -1090,6 +1090,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -1090,6 +1090,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
block_2_ctile_map
);
block_2_ctile_map
);
}
}
static
constexpr
auto
GetMPerBlock
()
{
return
MPerBlock
;
}
static
constexpr
auto
GetNPerBlock
()
{
return
NPerBlock
;
}
static
std
::
string
GetTypeString
()
static
std
::
string
GetTypeString
()
{
{
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment