gaoqiong / composable_kernel · Commits · b27909a0

Commit b27909a0
Authored Jun 28, 2023 by Jing Zhang
Parent e542dfc4

    clean

Showing 1 changed file with 8 additions and 369 deletions

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp   +8   -369
@@ -23,98 +23,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-#if 0
-template <typename GridwiseGemm,
-          typename GemmDesc,
-          typename GemmSharedArgs,
-          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const GemmSharedArgs gemm_shared_args)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
-    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    __shared__ uint8_t p_shared[shared_size];
-
-    const index_t block_id = get_block_1d_id();
-    const auto gemm_desc_ptr =
-        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
-
-    const index_t group_id = block_id / gemm_shared_args.block_size;
-
-#if 1
-    // const auto M       = gemm_shared_args.M;
-    // const auto N       = gemm_shared_args.N;
-    // const auto K       = gemm_shared_args.K;
-    // const auto StrideA = gemm_shared_args.StrideA;
-    // const auto StrideB = gemm_shared_args.StrideB;
-    // const auto StrideC = gemm_shared_args.StrideC;
-    // const auto MPadded = gemm_shared_args.MPadded;
-    // const auto NPadded = gemm_shared_args.NPadded;
-    // const auto KPadded = gemm_shared_args.KPadded;
-    // const auto K0      = gemm_shared_args.KPadded;
-    // const auto k_batch = gemm_shared_args.k_batch;
-
-    const auto M       = 2;
-    const auto N       = 768;
-    const auto K       = 4608;
-    const auto StrideA = 4608;
-    const auto StrideB = 4608;
-    const auto StrideC = 768;
-    const auto MPadded = 32;
-    const auto NPadded = 768;
-    const auto KPadded = 4608;
-    const auto K0      = 576;
-    const auto k_batch = 1;
-
-    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
-    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
-    static constexpr index_t B2E_M01   = 8;
-
-    const index_t block_start = gemm_shared_args.block_size * group_id;
-
-    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
-    using Block2ETileMapKSplit =
-        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
-
-    const auto c_grid_desc_m_n     = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
-    const auto local_b2c_tile_map  = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
-    auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-    const auto block_2_ctile_map   = grouped_block_2_ctile_map;
-#endif
-
-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        gemm_desc_ptr[group_id].karg_.p_a_grid,
-        gemm_desc_ptr[group_id].karg_.p_b_grid,
-        gemm_desc_ptr[group_id].karg_.p_c_grid,
-        M,
-        N,
-        K,
-        StrideA,
-        StrideB,
-        StrideC,
-        MPadded,
-        NPadded,
-        KPadded,
-        K0,
-        k_batch,
-        static_cast<void*>(p_shared),
-        block_2_ctile_map);
-#else
-    ignore = gemm_descs_const;
-    ignore = all_gemm_block_size;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}
-#elif 1
 template <typename GridwiseGemm,
           typename GemmDesc,
           bool HasMainKBlockLoop,
@@ -124,88 +33,7 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const index_t group_size)
+                                       const index_t group_count)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__))
-    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    __shared__ uint8_t p_shared[shared_size];
-
-    const auto gemm_desc_ptr =
-        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
-
-    for(index_t group_id = 0; group_id < group_size; group_id++)
-    {
-        const auto M       = gemm_desc_ptr[group_id].karg_.M;
-        const auto N       = gemm_desc_ptr[group_id].karg_.N;
-        const auto K       = gemm_desc_ptr[group_id].karg_.K;
-        const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
-        const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
-        const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
-        const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
-        const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
-        const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
-        const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
-        const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
-
-        static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
-        static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
-        static constexpr index_t B2E_M01   = 8;
-
-        using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
-        using Block2ETileMapKSplit =
-            BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
-        using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
-
-        const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
-        const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
-        const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
-
-        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-            gemm_desc_ptr[group_id].karg_.p_a_grid,
-            gemm_desc_ptr[group_id].karg_.p_b_grid,
-            gemm_desc_ptr[group_id].karg_.p_c_grid,
-            M,
-            N,
-            K,
-            StrideA,
-            StrideB,
-            StrideC,
-            MPadded,
-            NPadded,
-            KPadded,
-            K0,
-            k_batch,
-            static_cast<void*>(p_shared),
-            block_2_ctile_map);
-    }
-#else
-    ignore = gemm_descs_const;
-    ignore = group_count;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}
-#elif 0
-template <typename GridwiseGemm,
-          typename GemmDesc,
-          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-#if 0
-                                       const index_t N,
-                                       const index_t K,
-                                       const index_t StrideB,
-                                       const index_t NPadded,
-                                       const index_t KPadded,
-                                       const index_t K0,
-                                       const index_t k_batch,
-#endif
-                                       const index_t block_size)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__))
@@ -216,7 +44,6 @@ __global__ void
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
 
-#if 0
     index_t left     = 0;
     index_t right    = group_count;
     index_t group_id = index_t((left + right) / 2);
@@ -234,13 +61,7 @@ __global__ void
         }
         group_id = index_t((left + right) / 2);
     }
-
-#else
-    const index_t group_id = block_id / block_size;
-#endif
-
 #if 1
     const auto M = gemm_desc_ptr[group_id].karg_.M;
     const auto N = gemm_desc_ptr[group_id].karg_.N;
     const auto K = gemm_desc_ptr[group_id].karg_.K;
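For context on what the two hunks above re-enable: with the #if 0 / #else guards gone, each workgroup now maps its flat block id to a GEMM group by binary-searching the per-group [block_start_, block_end_) ranges carried in GemmTransKernelArg, instead of dividing by a uniform block_size. The exact loop body sits in the unchanged lines between the hunks; what follows is only a minimal host-side C++ sketch of that lookup, with hypothetical names (GroupRange, FindGroupId), not the kernel code itself.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical mirror of the per-group block range stored in GemmTransKernelArg.
struct GroupRange
{
    std::int32_t block_start_; // first flat block id owned by this group
    std::int32_t block_end_;   // one past the last flat block id of this group
};

// Binary search: find the group whose [block_start_, block_end_) contains block_id.
// Assumes block_id falls inside the range covered by the groups.
std::int32_t FindGroupId(const std::vector<GroupRange>& groups, std::int32_t block_id)
{
    std::int32_t left     = 0;
    std::int32_t right    = static_cast<std::int32_t>(groups.size());
    std::int32_t group_id = (left + right) / 2;

    while(!(block_id >= groups[group_id].block_start_ && block_id < groups[group_id].block_end_))
    {
        if(block_id < groups[group_id].block_start_)
        {
            right = group_id; // target range lies to the left
        }
        else
        {
            left = group_id + 1; // target range lies to the right
        }
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // Three groups of 4, 2 and 6 blocks laid out back to back.
    const std::vector<GroupRange> groups{{0, 4}, {4, 6}, {6, 12}};
    for(std::int32_t block_id : {0, 3, 4, 5, 11})
    {
        std::cout << "block " << block_id << " -> group " << FindGroupId(groups, block_id) << '\n';
    }
    return 0;
}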
@@ -252,21 +73,11 @@ __global__ void
     const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
     const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
     const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
-#else
-    const auto M       = gemm_desc_ptr[group_id].karg_.M;
-    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
-    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
-    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
-#endif
-
-    // const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
 
     static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
     static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
     static constexpr index_t B2E_M01   = 8;
 
-    const index_t block_start = block_size * group_id;
-
     using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
     using Block2ETileMapKSplit =
         BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
@@ -274,7 +85,7 @@ __global__ void
     const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
     const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
-    const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+    const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, gemm_desc_ptr[group_id].block_start_);
 
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
         gemm_desc_ptr[group_id].karg_.p_a_grid,
@@ -293,20 +104,11 @@ __global__ void
         k_batch,
         static_cast<void*>(p_shared),
         block_2_ctile_map);
-#else
-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        gemm_desc_ptr[group_id].karg_,
-        static_cast<void*>(p_shared),
-        gemm_desc_ptr[group_id].block_2_ctile_map_);
-#endif
 #else
     ignore = gemm_descs_const;
     ignore = group_count;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
-#endif
 
 template <typename ALayout,
           typename BLayout,
@@ -432,16 +234,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     struct GemmTransKernelArg
     {
         KernelArgument karg_;
-        // GroupedGemmBlock2ETileMap block_2_ctile_map_;
         index_t block_start_, block_end_;
 
         GemmTransKernelArg() = default;
         GemmTransKernelArg(KernelArgument&& karg,
-                           // GroupedGemmBlock2ETileMap&& b2c_map,
                            index_t block_start,
                            index_t block_end)
             : karg_{karg},
-              // block_2_ctile_map_{b2c_map},
              block_start_{block_start},
              block_end_{block_end}
         {
@@ -511,10 +310,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
 
-            const index_t block_start = 0;
-            const index_t block_end   = grid_size_grp;
+            const index_t block_start = grid_size_;
+            const index_t block_end   = grid_size_ + grid_size_grp;
 
-            grid_size_ = grid_size_grp;
+            grid_size_ += grid_size_grp;
 
             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
@@ -567,10 +366,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
 
-            const index_t block_start = 0;
-            const index_t block_end   = grid_size_grp;
+            const index_t block_start = grid_size_;
+            const index_t block_end   = grid_size_ + grid_size_grp;
 
-            grid_size_ = grid_size_grp;
+            grid_size_ += grid_size_grp;
 
             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
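The change repeated in the two hunks above restores cumulative grid bookkeeping on the host side: each group is assigned the contiguous block range [grid_size_, grid_size_ + grid_size_grp) before the running total advances, which is what the kernel-side block_start_/block_end_ search relies on. A minimal standalone C++ sketch of that bookkeeping, with hypothetical names (GroupArg, grid_size_grp_per_group):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the per-group block range stored alongside the kernel argument.
struct GroupArg
{
    std::int32_t block_start_;
    std::int32_t block_end_;
};

int main()
{
    // Per-group grid sizes, as CalculateGridSize(c_grid_desc_m_n) might report them.
    const std::vector<std::int32_t> grid_size_grp_per_group{4, 2, 6};

    std::vector<GroupArg> args;
    std::int32_t grid_size_ = 0; // running total over all groups

    for(const std::int32_t grid_size_grp : grid_size_grp_per_group)
    {
        const std::int32_t block_start = grid_size_;                 // was: 0
        const std::int32_t block_end   = grid_size_ + grid_size_grp; // was: grid_size_grp
        grid_size_ += grid_size_grp;                                 // was: grid_size_ = grid_size_grp

        args.push_back(GroupArg{block_start, block_end});
    }

    for(std::size_t i = 0; i < args.size(); ++i)
    {
        std::cout << "group " << i << ": blocks [" << args[i].block_start_ << ", "
                  << args[i].block_end_ << ")\n";
    }
    std::cout << "total grid size: " << grid_size_ << '\n'; // 12
    return 0;
}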
@@ -645,103 +444,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             }
         }
 
-        struct ArgumentMsN1K1
-        {
-            const ADataType* p_a_grid;
-            const BDataType* p_b_grid;
-            EDataType* p_c_grid;
-            // index_t M;
-            // index_t StrideA;
-            // index_t StrideC;
-            // index_t MPadded;
-            // GroupedGemmBlock2ETileMap block_2_ctile_map;
-        };
-
-        struct GemmTransKernelArgMsN1K1
-        {
-            ArgumentMsN1K1 karg_;
-        };
-
-#if 1
-        std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
-
-        // index_t all_gemm_block_size =
-        //     arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
-
-        for(const auto& trans_arg : arg.gemm_kernel_args_)
-        {
-            auto karg = ArgumentMsN1K1{
-                trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
-
-            auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
-
-            std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
-                      << " trans_arg.block_end_: " << trans_arg.block_end_
-                      << " block_size: " << block_size << std::endl;
-
-            gemm_kernel_args_msn1k1_.push_back({karg});
-        }
-#endif
-
-#if 1
         hip_check_error(hipMemcpy(arg.p_workspace_,
                                   arg.gemm_kernel_args_.data(),
                                   arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
                                   hipMemcpyHostToDevice));
-#else
-        struct GemmSharedArgs
-        {
-            index_t block_size;
-            // index_t M;
-            // index_t N;
-            // index_t K;
-            // index_t StrideA;
-            // index_t StrideB;
-            // index_t StrideC;
-            // index_t MPadded;
-            // index_t NPadded;
-            // index_t KPadded;
-            // index_t K0;
-            // index_t k_batch;
-            // GroupedGemmBlock2ETileMap block_2_ctile_map;
-#if 0
-            void print()
-            {
-                std::cout << "block_size = " << block_size << " M = " << M << " N = " << N
-                          << " K = " << K << " StrideA = " << StrideA
-                          << " StrideB = " << StrideB << " StrideC = " << StrideC
-                          << " MPadded = " << MPadded << " NPadded = " << NPadded
-                          << " KPadded = " << KPadded << " K0 = " << K0
-                          << " k_batch = " << k_batch << std::endl;
-            }
-#endif
-        };
-
-        auto shared_karg = GemmSharedArgs{
-            all_gemm_block_size,
-            // arg.gemm_kernel_args_[0].karg_.M,
-            // arg.gemm_kernel_args_[0].karg_.N,
-            // arg.gemm_kernel_args_[0].karg_.K,
-            // arg.gemm_kernel_args_[0].karg_.StrideA,
-            // arg.gemm_kernel_args_[0].karg_.StrideB,
-            // arg.gemm_kernel_args_[0].karg_.StrideC,
-            // arg.gemm_kernel_args_[0].karg_.MPadded,
-            // arg.gemm_kernel_args_[0].karg_.NPadded,
-            // arg.gemm_kernel_args_[0].karg_.KPadded,
-            // arg.gemm_kernel_args_[0].karg_.K0,
-            // arg.gemm_kernel_args_[0].karg_.k_batch,
-            // arg.gemm_kernel_args_[0].block_2_ctile_map_,
-        };
-        // shared_karg.print();
-
-        hip_check_error(hipMemcpy(arg.p_workspace_,
-                                  gemm_kernel_args_msn1k1_.data(),
-                                  gemm_kernel_args_msn1k1_.size() * sizeof(GemmTransKernelArgMsN1K1),
-                                  hipMemcpyHostToDevice));
-#endif
 
         float ave_time = 0;
 
         const auto Run = [&](const auto& kernel) {
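The path kept by this hunk marshals the host-side std::vector<GemmTransKernelArg> into the user-provided workspace (arg.p_workspace_) with a single hipMemcpy, so the kernel can read the per-group arguments through the constant-address-space pointer passed at launch. Below is a minimal HIP sketch of that marshalling pattern, assuming a hypothetical trivially copyable argument struct (KernelArg) and plain hipMalloc/hipGetErrorString error handling in place of CK's hip_check_error:

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <hip/hip_runtime.h>

// Hypothetical trivially copyable per-group argument, standing in for GemmTransKernelArg.
struct KernelArg
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    int block_start_;
    int block_end_;
};

#define HIP_CHECK(cmd)                                                        \
    do                                                                        \
    {                                                                         \
        hipError_t err_ = (cmd);                                              \
        if(err_ != hipSuccess)                                                \
        {                                                                     \
            std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err_)); \
            std::exit(EXIT_FAILURE);                                          \
        }                                                                     \
    } while(0)

int main()
{
    // Host-side argument list, one entry per GEMM group.
    std::vector<KernelArg> gemm_kernel_args(3);

    // Device workspace sized for the whole argument array; in the device op this
    // buffer is the user-provided workspace behind arg.p_workspace_.
    void* p_workspace        = nullptr;
    const std::size_t nbytes = gemm_kernel_args.size() * sizeof(KernelArg);
    HIP_CHECK(hipMalloc(&p_workspace, nbytes));

    // Single host-to-device copy, mirroring the kept hipMemcpy in the diff.
    HIP_CHECK(hipMemcpy(p_workspace, gemm_kernel_args.data(), nbytes, hipMemcpyHostToDevice));

    HIP_CHECK(hipFree(p_workspace));
    return 0;
}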
@@ -762,25 +468,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                           dim3(BlockSize),
                                           0,
                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
-#if 0
-                                          arg.gemm_kernel_args_[0].karg_.N,
-                                          arg.gemm_kernel_args_[0].karg_.K,
-                                          arg.gemm_kernel_args_[0].karg_.StrideB,
-                                          arg.gemm_kernel_args_[0].karg_.NPadded,
-                                          arg.gemm_kernel_args_[0].karg_.KPadded,
-                                          arg.gemm_kernel_args_[0].karg_.K0,
-                                          arg.gemm_kernel_args_[0].karg_.k_batch,
-#elif 0
-                                          all_gemm_block_size
-#elif 1
                                           arg.gemm_kernel_args_.size()
-#endif
                                           );
         };
 
-        std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
-                  << " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
-
-#if 1
         if(all_have_main_k0_block_loop)
         {
             if(all_have_kbatch_gt_one)
@@ -827,58 +518,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Run(kernel);
             }
         }
-#else
-        if(all_have_main_k0_block_loop)
-        {
-            if(all_have_kbatch_gt_one)
-            {
-                const auto kernel =
-                    kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArgMsN1K1,
-                                                   GemmSharedArgs,
-                                                   true,
-                                                   InMemoryDataOperationEnum::AtomicAdd>;
-                Run(kernel);
-            }
-            else
-            {
-                const auto kernel =
-                    kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArgMsN1K1,
-                                                   GemmSharedArgs,
-                                                   true,
-                                                   InMemoryDataOperationEnum::Set>;
-                Run(kernel);
-            }
-        }
-        else
-        {
-            if(all_have_kbatch_gt_one)
-            {
-                const auto kernel =
-                    kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArgMsN1K1,
-                                                   GemmSharedArgs,
-                                                   false,
-                                                   InMemoryDataOperationEnum::AtomicAdd>;
-                Run(kernel);
-            }
-            else
-            {
-                const auto kernel =
-                    kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArgMsN1K1,
-                                                   GemmSharedArgs,
-                                                   false,
-                                                   InMemoryDataOperationEnum::Set>;
-                Run(kernel);
-            }
-        }
-#endif
 
         return ave_time;
     }