Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6e9ef894
Commit
6e9ef894
authored
Feb 17, 2025
by
rtmadduri
Browse files
applied changes to tail num lambda, clean up ctrs
parent
bf73d297
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
282 additions
and
406 deletions
+282
-406
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+282
-406
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
6e9ef894
...
...
@@ -76,7 +76,8 @@ __global__ void
karg
,
karg
.
a_element_op
,
karg
.
b_element_op
,
karg
.
c_element_op
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
karg
.
c_element_op
,
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
#else
ignore
=
gemm_descs_const
;
ignore
=
group_count
;
...
...
@@ -137,20 +138,18 @@ template <typename ALayout,
// MultipleD not supported for now.
enable_if_t
<
is_same_v
<
DsLayout
,
ck
::
Tuple
<
>
>
&&
is_same_v
<
DsDataType
,
ck
::
Tuple
<>>
,
bool
>
=
false
>
>
struct
DeviceGroupedGemmXdlSplitKCShuffle
:
public
DeviceGroupedGemmSplitK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
struct
DeviceGroupedGemmXdlSplitKCShuffle
:
public
DeviceGroupedGemmSplitK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -221,7 +220,7 @@ template <typename ALayout,
GroupedGemmBlock2ETileMap
block_2_ctile_map_
;
index_t
block_start_
,
block_end_
;
GemmTransKernelArg
()
=
default
;
//
GemmTransKernelArg() = default;
GemmTransKernelArg
(
KernelArgument
&&
karg
,
GroupedGemmBlock2ETileMap
&&
b2c_map
,
index_t
block_start
,
...
...
@@ -243,11 +242,8 @@ template <typename ALayout,
Argument
(
std
::
vector
<
const
void
*>&
p_a_grid
,
std
::
vector
<
const
void
*>&
p_b_grid
,
std
::
vector
<
void
*>&
p_c_grid
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
))
:
Argument
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
,
DefaultKBatch
,
a_element_op
,
b_element_op
,
cde_element_op
)
std
::
vector
<
GemmDesc
>&
gemm_descs
)
:
Argument
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
,
DefaultKBatch
)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
...
...
@@ -256,10 +252,7 @@ template <typename ALayout,
std
::
vector
<
const
void
*>&
p_b_grid
,
std
::
vector
<
void
*>&
p_c_grid
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
index_t
kbatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
))
index_t
kbatch
)
:
K_BATCH
{
kbatch
}
{
grid_size_
=
0
;
...
...
@@ -307,19 +300,14 @@ template <typename ALayout,
KernelArgument
karg
{
type_convert
<
const
ADataType
*>
(
p_a_grid
[
i
]),
type_convert
<
const
BDataType
*>
(
p_b_grid
[
i
]),
{},
// p_ds_grid
type_convert
<
EDataType
*>
(
p_c_grid
[
i
]),
M
,
N
,
K
,
stride_a
,
stride_b
,
{},
// StrideDs_
stride_c
,
K_BATCH
,
a_element_op
,
b_element_op
,
cde_element_op
};
K_BATCH
};
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
std
::
move
(
grouped_block_2_ctile_map
),
block_start
,
block_end
);
...
...
@@ -341,8 +329,8 @@ template <typename ALayout,
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
M
,
N
,
4
};
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
M
,
N
);
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
karg
.
M
,
karg
.
N
,
4
};
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
);
grid_size_grp
*=
K_BATCH
;
const
index_t
block_start
=
grid_size_
;
...
...
@@ -380,16 +368,17 @@ template <typename ALayout,
bool
all_have_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split0
);
const
auto
tail_num
=
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split0
);
bool
all_have_kbatch_gt_one
=
karg0
.
KBatch
>
1
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_kernel_args_
.
size
();
++
i
)
{
const
auto
&
karg
=
arg
.
gemm_kernel_args_
[
i
].
karg_
;
if
(
stream_config
.
log_level_
>
0
)
{
karg
.
Print
();
}
const
auto
&
karg
=
arg
.
gemm_kernel_args_
[
i
].
karg_
;
index_t
k_grain
=
karg
.
KBatch
*
KPerBlock
;
index_t
K_split
=
(
karg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
...
...
@@ -460,11 +449,16 @@ template <typename ALayout,
rotating_mem
.
Next
();
// clear c mem
// TODO: should be loop here through all groups
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
for
(
const
auto
&
trans_arg
:
arg
.
gemm_kernel_args_
)
{
const
auto
&
karg
=
trans_arg
.
karg_
;
if
(
karg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
EDataType
),
stream_config
.
stream_id_
));
}
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -480,11 +474,15 @@ template <typename ALayout,
else
{
// TODO: should be loop here through all groups
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
for
(
const
auto
&
trans_arg
:
arg
.
gemm_kernel_args_
)
{
const
auto
&
karg
=
trans_arg
.
karg_
;
if
(
karg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
EDataType
),
stream_config
.
stream_id_
));
}
ave_time
=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -537,7 +535,8 @@ template <typename ALayout,
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
true
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
...
...
@@ -546,313 +545,231 @@ template <typename ALayout,
//// TODO: Fix below as above!
else
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Full
)
else
if
(
tail_num
==
TailNumber
::
Full
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Two
)
if
(
tail_num
==
TailNumber
::
Two
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Three
)
if
(
tail_num
==
TailNumber
::
Three
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Four
)
if
(
tail_num
==
TailNumber
::
Four
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Five
)
if
(
tail_num
==
TailNumber
::
Five
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Six
)
if
(
tail_num
==
TailNumber
::
Six
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Seven
)
if
(
tail_num
==
TailNumber
::
Seven
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
else
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
One
)
if
(
tail_num
==
TailNumber
::
One
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Full
)
else
if
(
tail_num
==
TailNumber
::
Full
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Two
)
if
(
tail_num
==
TailNumber
::
Two
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Three
)
if
(
tail_num
==
TailNumber
::
Three
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Four
)
if
(
tail_num
==
TailNumber
::
Four
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Five
)
if
(
tail_num
==
TailNumber
::
Five
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Six
)
if
(
tail_num
==
TailNumber
::
Six
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Seven
)
if
(
tail_num
==
TailNumber
::
Seven
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
...
...
@@ -862,77 +779,57 @@ template <typename ALayout,
{
if
(
all_have_kbatch_gt_one
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Odd
)
if
(
tail_num
==
TailNumber
::
Odd
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Odd
)
if
(
tail_num
==
TailNumber
::
Odd
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
...
...
@@ -941,78 +838,57 @@ template <typename ALayout,
{
if
(
all_have_kbatch_gt_one
)
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Odd
)
if
(
tail_num
==
TailNumber
::
Odd
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
calculate_
tail_num
ber
()
==
TailNumber
::
Odd
)
if
(
tail_num
==
TailNumber
::
Odd
)
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
if
(
all_have_same_tail_number
())
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
else
{
throw_error
();
}
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
...
...
@@ -1025,19 +901,21 @@ template <typename ALayout,
if
(
all_have_kbatch_gt_one
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArg
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
...
...
@@ -1115,12 +993,11 @@ template <typename ALayout,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
void
*>&
p_c_grid
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
)
{
return
Argument
{
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
};
return
Argument
{
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -1132,12 +1009,11 @@ template <typename ALayout,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
void
*>&
p_c_grid
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
return
std
::
make_unique
<
Argument
>
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
gemm_descs
);
}
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment