GitLab · gaoqiong / composable_kernel_ROCM · Commits
Commit bf73d297, authored Feb 07, 2025 by Adam Osewski

Fixes

parent a5e9069f
Showing 1 changed file with 189 additions and 166 deletions.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp (+189, -166)
@@ -28,18 +28,14 @@ template <typename GridwiseGemm,
           typename GemmDesc,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AElementwiseOperation   = ck::tensor_operation::element_wise::PassThrough,
-          typename BElementwiseOperation   = ck::tensor_operation::element_wise::PassThrough,
-          typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
+          index_t MinimumOccupancy = 1,
+          TailNumber TailNum       = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   const index_t group_count,
-                                   const AElementwiseOperation a_element_op,
-                                   const BElementwiseOperation b_element_op,
-                                   const CDEElementwiseOperation c_element_op)
+                                   const index_t group_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
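Note: this hunk trades the per-kernel element-wise template parameters for a MinimumOccupancy knob that feeds __launch_bounds__ directly. A minimal standalone sketch of that pattern, with a hypothetical kernel and hypothetical thread counts (not CK code):

    #include <hip/hip_runtime.h>

    // Sketch only: a template parameter driving __launch_bounds__, the same
    // pattern MinimumOccupancy uses in the kernel above.
    template <int MaxThreads, int MinBlocksPerCU>
    __global__ void __launch_bounds__(MaxThreads, MinBlocksPerCU) scale_kernel(float* p)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        p[i] *= 2.0f;
    }

    // The device op later in this diff picks the bound from the pipeline
    // scheduler: Intrawave pipelines request minimum occupancy 1, others 2.
    constexpr auto kernel_occ1 = scale_kernel<256, 1>;
    constexpr auto kernel_occ2 = scale_kernel<256, 2>;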
@@ -68,13 +64,19 @@ __global__ void
         group_id = index_t((left + right) / 2);
     }

-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        gemm_desc_ptr[group_id].karg_,
-        static_cast<void*>(p_shared),
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        a_element_op,
-        b_element_op,
-        c_element_op);
+    const auto karg           = gemm_desc_ptr[group_id].karg_;
+    auto splitk_batch_offset  = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+
+    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
+        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+        karg.p_ds_grid,
+        karg.p_c_grid,
+        p_shared,
+        karg,
+        karg.a_element_op,
+        karg.b_element_op,
+        karg.c_element_op,
+        gemm_desc_ptr[group_id].block_2_ctile_map_);
 #else
     ignore = gemm_descs_const;
     ignore = group_count;
@@ -131,19 +133,24 @@ template <typename ALayout,
           typename ComputeTypeA = EDataType,
           typename ComputeTypeB = ComputeTypeA,
           bool PermuteA         = false,
-          bool PermuteB         = false>
-struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
-                                                                           BLayout,
-                                                                           DsLayout,
-                                                                           ELayout,
-                                                                           ADataType,
-                                                                           BDataType,
-                                                                           DsDataType,
-                                                                           EDataType,
-                                                                           AElementwiseOperation,
-                                                                           BElementwiseOperation,
-                                                                           CDEElementwiseOperation>
+          bool PermuteB         = false,
+          // MultipleD not supported for now.
+          enable_if_t<is_same_v<DsLayout, ck::Tuple<>> &&
+                          is_same_v<DsDataType, ck::Tuple<>>,
+                      bool> = false>
+struct DeviceGroupedGemmXdlSplitKCShuffle
+    : public DeviceGroupedGemmSplitK<ALayout,
+                                     BLayout,
+                                     DsLayout,
+                                     ELayout,
+                                     ADataType,
+                                     BDataType,
+                                     DsDataType,
+                                     EDataType,
+                                     AElementwiseOperation,
+                                     BElementwiseOperation,
+                                     CDEElementwiseOperation>
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -198,7 +205,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                                                BlkGemmPipeSched,
                                                BlkGemmPipelineVer,
                                                ComputeTypeA,
-                                               ComputeTypeB>;
+                                               ComputeTypeB,
+                                               PermuteA,
+                                               PermuteB>;

     using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
@@ -209,16 +218,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
     struct GemmTransKernelArg
     {
         KernelArgument karg_;
-        // GroupedGemmBlock2ETileMap block_2_ctile_map_;
+        GroupedGemmBlock2ETileMap block_2_ctile_map_;
         index_t block_start_, block_end_;

         GemmTransKernelArg() = default;
         GemmTransKernelArg(KernelArgument&& karg,
-                           // GroupedGemmBlock2ETileMap&& b2c_map,
+                           GroupedGemmBlock2ETileMap&& b2c_map,
                            index_t block_start,
                            index_t block_end)
             : karg_{karg},
-              // block_2_ctile_map_{b2c_map},
+              block_2_ctile_map_{b2c_map},
               block_start_{block_start},
               block_end_{block_end}
         {
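Each group's restored [block_start_, block_end_) range is what the kernel's binary search (shown in the first hunk's context) walks to map a flat block id back to its group. A host-side analogue of that lookup, assuming the ranges tile the grid contiguously; names are local to the sketch:

    #include <vector>

    struct GroupRange { int block_start_, block_end_; }; // half-open [start, end)

    // Host-side analogue of the kernel's group lookup. Assumes block_id
    // falls inside one of the contiguous ranges.
    int find_group(const std::vector<GroupRange>& groups, int block_id)
    {
        int left = 0, right = static_cast<int>(groups.size());
        int group_id = (left + right) / 2;
        while(!(block_id >= groups[group_id].block_start_ &&
                block_id <  groups[group_id].block_end_))
        {
            if(block_id < groups[group_id].block_start_) { right = group_id; }
            else                                         { left  = group_id; }
            group_id = (left + right) / 2;
        }
        return group_id;
    }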
@@ -234,8 +243,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
         Argument(std::vector<const void*>& p_a_grid,
                  std::vector<const void*>& p_b_grid,
                  std::vector<void*>& p_c_grid,
-                 std::vector<GemmDesc>& gemm_descs)
-            : Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch)
+                 std::vector<GemmDesc>& gemm_descs,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CDEElementwiseOperation cde_element_op)
+            : Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch, a_element_op, b_element_op, cde_element_op)
         {
             // TODO: use occupancy api to calculate appropriate batch size.
         }
@@ -244,7 +256,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                  std::vector<const void*>& p_b_grid,
                  std::vector<void*>& p_c_grid,
                  std::vector<GemmDesc>& gemm_descs,
-                 index_t kbatch)
+                 index_t kbatch,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CDEElementwiseOperation cde_element_op)
             : K_BATCH{kbatch}
         {
             grid_size_ = 0;

@@ -267,7 +282,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             const index_t N = gemm_descs[i].N_;
             const index_t K = gemm_descs[i].K_;

-            if(M == 0)
+            if(M * N * K == 0)
             {
                 skipped_group_count_++;
                 continue;
@@ -277,12 +292,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;

-            index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH);
-            const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
-            const index_t grid_size_grp   = local_b2c_tile_map.CalculateGridSize(M, N);
-            // const index_t grid_size_grp = gdx * gdy * gdz;
+            const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
+            index_t grid_size_grp         = local_b2c_tile_map.CalculateGridSize(M, N);
+            grid_size_grp *= K_BATCH;

             const index_t block_start = grid_size_;
             const index_t block_end   = grid_size_ + grid_size_grp;
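The restored Block2ETileMap-based sizing amounts to M-tiles times N-tiles times the split-K factor, which then becomes the group's slice of the fused grid. A sketch of that arithmetic with hypothetical tile sizes (the real values come from the MPerBlock/NPerBlock template parameters):

    constexpr int MPerBlockSketch = 128;
    constexpr int NPerBlockSketch = 128;

    // Grid size of one group = ceil(M / MPerBlock) * ceil(N / NPerBlock) * KBatch,
    // matching local_b2c_tile_map.CalculateGridSize(M, N) * K_BATCH above.
    int group_grid_size(int M, int N, int KBatch)
    {
        const int m_tiles = (M + MPerBlockSketch - 1) / MPerBlockSketch;
        const int n_tiles = (N + NPerBlockSketch - 1) / NPerBlockSketch;
        return m_tiles * n_tiles * KBatch;
    }
    // e.g. M = N = 1000, KBatch = 4 -> 8 * 8 * 4 = 256 blocks, occupying
    // [block_start, block_start + 256) in the fused grid.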
@@ -290,24 +302,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             grid_size_ += grid_size_grp;

             // block-to-e-tile map
-            // auto grouped_block_2_ctile_map =
-            //     GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+            auto grouped_block_2_ctile_map =
+                GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

             KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
                                 type_convert<const BDataType*>(p_b_grid[i]),
+                                {}, // p_ds_grid
                                 type_convert<EDataType*>(p_c_grid[i]),
                                 M,
                                 N,
                                 K,
                                 stride_a,
                                 stride_b,
+                                {}, // StrideDs_
                                 stride_c,
-                                K_BATCH};
+                                K_BATCH,
+                                a_element_op,
+                                b_element_op,
+                                cde_element_op};

-            // gemm_kernel_args_.emplace_back(
-            //     std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
-            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
+            gemm_kernel_args_.emplace_back(
+                std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
         }
     }
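With this hunk the element-wise operators ride inside each group's KernelArgument instead of being kernel parameters. For reference, a standalone sketch of the functor shape those slots hold; ck::tensor_operation::element_wise::PassThrough is essentially an identity of this form (this struct name is local to the sketch):

    #include <hip/hip_runtime.h>

    struct PassThroughSketch
    {
        template <typename Y, typename X>
        __host__ __device__ void operator()(Y& y, const X& x) const
        {
            y = static_cast<Y>(x); // real CK ops may convert or fuse an epilogue here
        }
    };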
@@ -326,28 +341,22 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             auto& karg = gemm_kernel_args_[i].karg_;

-            // const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
-            // const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
-            index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) =
-                GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
-            const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
-            const index_t grid_size_grp   = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
-            // const index_t grid_size_grp = gdx * gdy * gdz;
+            const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
+            index_t grid_size_grp         = local_b2c_tile_map.CalculateGridSize(M, N);
+            grid_size_grp *= K_BATCH;

             const index_t block_start = grid_size_;
             const index_t block_end   = grid_size_ + grid_size_grp;
             grid_size_ += grid_size_grp;

-            // auto grouped_block_2_ctile_map =
-            //     GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+            // block-to-e-tile map
+            auto grouped_block_2_ctile_map =
+                GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
             karg.KBatch = K_BATCH;
-            // gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
+            gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
             gemm_kernel_args_[i].block_start_ = block_start;
             gemm_kernel_args_[i].block_end_   = block_end;
         }
     }
@@ -365,45 +374,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            bool all_have_main_k_block_loop{true};
-            bool all_have_kbatch_gt_one;
+            const auto& karg0 = arg.gemm_kernel_args_[0].karg_;
+
+            index_t k_grain0 = karg0.KBatch * KPerBlock;
+            index_t K_split0 = (karg0.K + k_grain0 - 1) / k_grain0 * KPerBlock;
+
+            bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0);
+            const auto tail_num             = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0);

             for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
             {
-                const auto& karg       = arg.gemm_kernel_args_[i].karg_;
-                all_have_kbatch_gt_one = karg.KBatch > 1;
-                index_t k_grain        = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
-                index_t K_split =
-                    (arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
-                all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
+                const auto& karg = arg.gemm_kernel_args_[i].karg_;

                 if(stream_config.log_level_ > 0)
                 {
                     karg.Print();
                 }

-                auto kbatch = karg.KBatch;
-                if(!GridwiseGemm::CheckValidity(karg))
+                index_t k_grain = karg.KBatch * KPerBlock;
+                index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock;
+
+                bool not_all_have_main_k0_block_loop_same =
+                    all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
+                bool not_all_have_tail_num_same =
+                    (tail_num == GridwiseGemm::CalculateKBlockLoopTailNum(K_split));
+
+                if(not_all_have_main_k0_block_loop_same)
                 {
                     std::ostringstream err;
-                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
+                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                         << ":" << __LINE__ << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
-
-                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
-                if(not_all_have_kbatch_value_same)
+                if(not_all_have_tail_num_same)
                 {
                     std::ostringstream err;
-                    err << "Not all gemms have same kbatch value (=1 or >1)! "
-                        << "group [" << i << "], kbatch: " << kbatch << ", group [0], kbatch: "
-                        << arg.gemm_kernel_args_[0].karg_.KBatch << " in " << __FILE__ << ":"
-                        << __LINE__ << ", in function: " << __func__;
+                    err << "Not all gemms have same TailNumber value! in " << __FILE__ << ":"
+                        << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+                if(!GridwiseGemm::CheckValidity(karg))
+                {
+                    std::ostringstream err;
+                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
             }
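The k_grain/K_split expressions above round each group's split-K slice up to whole KPerBlock tiles, and the new checks require every group to land in the same main-loop/tail configuration. A worked sketch of the arithmetic, with a hypothetical KPerBlock value:

    constexpr int KPerBlockSketch = 64;

    int k_split(int K, int KBatch)
    {
        const int k_grain = KBatch * KPerBlockSketch;          // K consumed per tile across all batches
        return (K + k_grain - 1) / k_grain * KPerBlockSketch;  // per-batch K, rounded up to whole tiles
    }
    // Example: K = 1000, KBatch = 4 -> k_grain = 256, ceil(1000/256) = 4,
    // so K_split = 4 * 64 = 256. Groups whose K_split implies a different
    // main-loop or tail configuration cannot share one kernel instantiation,
    // hence the runtime_error branches above.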
@@ -418,64 +435,71 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
-                if(all_have_kbatch_gt_one)
+                if(stream_config.flush_cache)
                 {
-                    for(const auto& trans_arg : arg.gemm_kernel_args_)
-                    {
-                        const auto& karg = trans_arg.karg_;
-                        hip_check_error(hipMemsetAsync(karg.p_c_grid,
-                                                       0,
-                                                       karg.M * karg.N * sizeof(EDataType),
-                                                       stream_config.stream_id_));
-                    }
-                }
-                for(const auto& trans_arg : arg.gemm_kernel_args_)
-                {
-                    const auto& karg = trans_arg.karg_;
-                    ave_time += launch_and_time_kernel(
-                        stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0, karg);
+                    const auto& arg_ = arg.gemm_kernel_args_[0].karg_;
+
+                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
+                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
+                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
+
+                    auto size_a_buffer =
+                        a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
+                    auto size_b_buffer =
+                        b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
+
+                    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
+                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
+                    rotating_mem.Print();
+
+                    auto run_flush_cache = [&]() {
+                        // flush icache
+                        ck::utility::flush_icache();
+                        // rotating mem
+                        rotating_mem.Next();
+                        // clear c mem
+                        // TODO: should be loop here through all groups
+                        if(arg_.KBatch > 1)
+                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
+                                                             0,
+                                                             arg_.M * arg_.N * sizeof(CDataType),
+                                                             stream_config.stream_id_));
+                    };
+
+                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
+                        stream_config,
+                        run_flush_cache,
+                        kernel,
+                        dim3(arg.grid_size_),
+                        dim3(BlockSize),
+                        0,
+                        cast_pointer_to_constant_address_space(arg.p_workspace_),
+                        arg.gemm_kernel_args_.size());
+                }
+                else
+                {
+                    // TODO: should be loop here through all groups
+                    if(arg.KBatch > 1)
+                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
+                                                         0,
+                                                         arg.M * arg.N * sizeof(CDataType),
+                                                         stream_config.stream_id_));
+
+                    ave_time = launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(arg.grid_size_),
+                        dim3(BlockSize),
+                        0,
+                        cast_pointer_to_constant_address_space(arg.p_workspace_),
+                        arg.gemm_kernel_args_.size());
                 }
             };

             constexpr index_t minimum_occupancy =
                 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

             // Calculate TailNumber for one
             auto calculate_tail_number = [&]() {
                 index_t k_grain = arg.gemm_kernel_args_[0].karg_.KBatch * KPerBlock;
                 index_t K_split =
                     (arg.gemm_kernel_args_[0].karg_.K + k_grain - 1) / k_grain * KPerBlock;
                 return GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
             };
             auto all_have_same_tail_number = [&]() {
                 // Calculate TailNumber for one
                 auto tail_number = calculate_tail_number();
                 // Calculate TailNumber for every other arg and compare
                 for(size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
                 {
                     index_t k_grain = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
                     index_t K_split =
                         (arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
                     if(tail_number != GridwiseGemm::CalculateKBlockLoopTailNum(K_split))
                     {
                         return false;
                     }
                 }
                 return true;
             };
             auto throw_error = [&]() {
                 std::ostringstream err;
                 err << "Not all gemms have same TailNumber value! ";
                 throw std::runtime_error(err.str());
             };

             if(all_have_main_k_block_loop)
             {
                 // Tail number always full
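The new flush_cache path runs a preprocess callback (icache flush, buffer rotation, C clear) before each timed launch so every iteration sees cold caches. A generic sketch of that timing pattern using plain HIP events; the function and parameter names here are local to the sketch, not the CK utility API:

    #include <hip/hip_runtime.h>

    // "Preprocess, then launch, averaged over nrepeat" timing, the pattern
    // behind launch_and_time_kernel_with_preprocess.
    template <typename Preprocess, typename Launch>
    float time_with_preprocess(hipStream_t stream, int nrepeat, Preprocess pre, Launch launch)
    {
        hipEvent_t start, stop;
        (void)hipEventCreate(&start);
        (void)hipEventCreate(&stop);

        float total_ms = 0.f;
        for(int r = 0; r < nrepeat; ++r)
        {
            pre(); // e.g. flush icache, rotate A/B buffers, clear C for atomics
            (void)hipEventRecord(start, stream);
            launch();
            (void)hipEventRecord(stop, stream);
            (void)hipEventSynchronize(stop);
            float ms = 0.f;
            (void)hipEventElapsedTime(&ms, start, stop);
            total_ms += ms;
        }
        (void)hipEventDestroy(start);
        (void)hipEventDestroy(stop);
        return total_ms / nrepeat;
    }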
@@ -485,19 +509,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                 if(all_have_kbatch_gt_one)
                 {
                     const auto kernel =
-                        kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                    true,
-                                                    InMemoryDataOperationEnum::AtomicAdd,
-                                                    minimum_occupancy>;
+                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
+                                                       GemmTransKernelArg,
+                                                       true,
+                                                       InMemoryDataOperationEnum::AtomicAdd,
+                                                       minimum_occupancy>;
                     Run(kernel);
                 }
                 else
                 {
                     const auto kernel =
-                        kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                    true,
-                                                    InMemoryDataOperationEnum::Set,
-                                                    minimum_occupancy>;
+                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
+                                                       GemmTransKernelArg,
+                                                       true,
+                                                       InMemoryDataOperationEnum::Set,
+                                                       minimum_occupancy>;
                     Run(kernel);
                 }
             }
@@ -507,24 +533,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
             {
                 if(all_have_kbatch_gt_one)
                 {
-                    if(calculate_tail_number() == TailNumber::One)
+                    if(tail_num == TailNumber::One)
                     {
-                        if(all_have_same_tail_number())
-                        {
-                            const auto kernel =
-                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                            true,
-                                                            InMemoryDataOperationEnum::AtomicAdd,
-                                                            minimum_occupancy,
-                                                            TailNumber::One>;
-                            Run(kernel);
-                        }
-                        else
-                        {
-                            throw_error();
-                        }
+                        const auto kernel =
+                            kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
+                                                           GemmTransKernelArg,
+                                                           true,
+                                                           InMemoryDataOperationEnum::AtomicAdd,
+                                                           minimum_occupancy,
+                                                           TailNumber::One>;
+                        Run(kernel);
                     }
+                    //// TODO: Fix below as above!
                     else if(calculate_tail_number() == TailNumber::Full)
                     {
                         if(all_have_same_tail_number())
@@ -1094,11 +1115,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                             std::vector<std::array<const void*, NumDTensor>>&,
                             std::vector<void*>& p_c_grid,
                             std::vector<GemmDesc> gemm_descs,
-                            AElementwiseOperation,
-                            BElementwiseOperation,
-                            CDEElementwiseOperation)
+                            AElementwiseOperation a_element_op,
+                            BElementwiseOperation b_element_op,
+                            CDEElementwiseOperation cde_element_op)
     {
-        return Argument{p_a_grid, p_b_grid, p_c_grid, gemm_descs};
+        return Argument{
+            p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -1110,11 +1132,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                             std::vector<std::array<const void*, NumDTensor>>&,
                             std::vector<void*>& p_c_grid,
                             std::vector<GemmDesc>& gemm_descs,
-                            AElementwiseOperation,
-                            BElementwiseOperation,
-                            CDEElementwiseOperation) override
+                            AElementwiseOperation a_element_op,
+                            BElementwiseOperation b_element_op,
+                            CDEElementwiseOperation cde_element_op) override
     {
-        return std::make_unique<Argument>(p_a_grid, p_b_grid, p_c_grid, gemm_descs);
+        return std::make_unique<Argument>(
+            p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op);
     }

     // polymorphic