Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ee1c88b1
Commit
ee1c88b1
authored
Jul 27, 2023
by
Jing Zhang
Browse files
add setkbatch
parent
41a1466a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
129 additions
and
77 deletions
+129
-77
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+8
-4
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
...sor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+104
-59
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
...gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
+16
-14
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
ee1c88b1
...
@@ -57,7 +57,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
...
@@ -57,7 +57,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
2
,
0
,
1
,
3
>
,
S
<
2
,
0
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
2
,
0
,
1
,
3
>
,
S
<
2
,
0
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
8
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
// clang-format on
struct
ProblemSize
final
struct
ProblemSize
final
...
@@ -77,6 +78,7 @@ struct ExecutionConfig final
...
@@ -77,6 +78,7 @@ struct ExecutionConfig final
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
int
k_batch
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
...
@@ -238,6 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -238,6 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
}
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
gemm
.
SetKBatch
(
argument
,
config
.
k_batch
);
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
...
@@ -293,8 +296,7 @@ int main(int argc, char* argv[])
...
@@ -293,8 +296,7 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
118
,
0
,
1
,
148
};
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
...
@@ -306,17 +308,19 @@ int main(int argc, char* argv[])
...
@@ -306,17 +308,19 @@ int main(int argc, char* argv[])
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
}
}
if
(
argc
==
4
)
if
(
argc
==
5
)
{
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
k_batch
=
std
::
stoi
(
argv
[
4
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg4: k_batch (> 0)
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
View file @
ee1c88b1
...
@@ -54,6 +54,7 @@ struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
...
@@ -54,6 +54,7 @@ struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
CElementwiseOperation
>
CElementwiseOperation
>
{
{
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
// Updates the k_batch setting stored in the argument object pointed to by p_arg.
// Pure virtual: each concrete device op supplies its own implementation.
// NOTE(review): the implementation visible in this change downcasts p_arg to its own
// Argument type and rejects k_batch < 1 -- confirm the same contract for other overriders.
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
ee1c88b1
...
@@ -36,6 +36,7 @@ template <typename GridwiseGemm,
...
@@ -36,6 +36,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
...
@@ -103,25 +104,29 @@ __global__ void
...
@@ -103,25 +104,29 @@ __global__ void
const
auto
block_2_etile_map
=
const
auto
block_2_etile_map
=
GroupedGemmBlock2ETileMap
(
local_b2e_tile_map
,
BlockStart
,
id_off
);
GroupedGemmBlock2ETileMap
(
local_b2e_tile_map
,
BlockStart
,
id_off
);
GridwiseGemm
::
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
template
Run
<
HasMainKBlockLoop
,
GemmSpec
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
>(
EGlobalMemoryDataOperation
,
gemm_desc_ptr
[
group_id
].
p_a_grid
,
GemmSpec
,
gemm_desc_ptr
[
group_id
].
p_b_grid
,
ALayout
,
p_ds_grid_
,
BLayout
,
gemm_desc_ptr
[
group_id
].
p_e_grid
,
DsLayout
,
p_shared
,
ELayout
>(
gemm_desc_ptr
[
group_id
].
p_a_grid
,
a_element_op
,
gemm_desc_ptr
[
group_id
].
p_b_grid
,
b_element_op
,
p_ds_grid_
,
c_element_op
,
gemm_desc_ptr
[
group_id
].
p_e_grid
,
M
,
p_shared
,
N
,
a_element_op
,
K
,
b_element_op
,
StrideA
,
c_element_op
,
StrideB
,
M
,
StrideDs
,
N
,
StrideE
,
K
,
KBatch
,
StrideA
,
block_2_etile_map
);
StrideB
,
StrideDs
,
StrideE
,
KBatch
,
block_2_etile_map
);
id_off
+=
grid_size_grp
;
id_off
+=
grid_size_grp
;
}
}
...
@@ -195,8 +200,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -195,8 +200,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
const
index_t
k_batch
=
2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -211,7 +214,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -211,7 +214,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
AtomicAdd
,
NumPrefetch
,
// NumGemmKPrefetchStage
NumPrefetch
,
// NumGemmKPrefetchStage
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -406,6 +408,33 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -406,6 +408,33 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
/// Re-configures this argument for a new k_batch value.
///
/// Stores k_batch in k_batch_ and recomputes grid_size_grp_ and grid_size_ from it.
/// The E grid descriptor used for the grid-size computation is built from the
/// average M over all groups (sum_of_m / group_count_) together with N and StrideE
/// of group 0.
///
/// @param k_batch new k_batch value; must be >= 1.
/// @throws std::runtime_error if k_batch < 1, or if the argument holds no groups.
///         In either case no member of the argument has been modified.
void UpdateKBatch(index_t k_batch)
{
    // Validate before touching any member, so a rejected value cannot leave the
    // argument with an invalid k_batch_ next to stale grid sizes.
    if(k_batch < 1)
    {
        throw std::runtime_error("wrong! k_batch must be > 0");
    }

    // AverM divides by group_count_ and group 0 is read below; both need >= 1 group.
    if(group_count_ < 1 || gemm_desc_kernel_arg_.size() == 0)
    {
        throw std::runtime_error("wrong! group_count_ must be > 0");
    }

    const index_t AverM   = sum_of_m / group_count_;
    const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
    const index_t N       = gemm_desc_kernel_arg_[0].N_;

    const auto e_grid_desc_m_n =
        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(AverM, N, StrideE);

    const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch};

    const index_t new_grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

    // Commit only after every computation above has succeeded.
    k_batch_       = k_batch;
    grid_size_grp_ = new_grid_size_grp;
    grid_size_     = grid_size_grp_ * group_count_;
}
Argument
(
std
::
vector
<
const
void
*>&
p_As
,
Argument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
...
@@ -418,6 +447,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -418,6 +447,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{
{
grid_size_
=
0
;
grid_size_
=
0
;
k_batch_
=
1
;
grouped_gemm_kernel_args_dev
=
nullptr
;
grouped_gemm_kernel_args_dev
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_descs
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_descs
.
size
());
...
@@ -497,19 +528,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -497,19 +528,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AverM
,
N
,
StrideE
);
AverM
,
N
,
StrideE
);
// block-to-e-tile map
// block-to-e-tile map
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
,
k_batch
};
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
e_grid_desc_m_n
,
k_batch_
};
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
// std::cout << "group_id: " << group_id << " grid_size_grp: " << grid_size_grp
grid_size_grp_
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
//<< std::endl;
if
(
group_id
*
grid_size_grp
!=
grid_size_
)
if
(
group_id
*
grid_size_grp
_
!=
grid_size_
)
{
{
throw
std
::
runtime_error
(
"wrong! grid_size_grp is not identical!"
);
throw
std
::
runtime_error
(
"wrong! grid_size_grp
_
is not identical!"
);
}
}
grid_size_
+=
grid_size_grp
;
grid_size_
+=
grid_size_grp
_
;
// check block-to-E-tile
// check block-to-E-tile
if
(
!
local_b2c_tile_map
.
CheckValidity
(
e_grid_desc_m_n
))
if
(
!
local_b2c_tile_map
.
CheckValidity
(
e_grid_desc_m_n
))
...
@@ -557,8 +585,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -557,8 +585,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const
void
*
grouped_gemm_kernel_args_dev
;
const
void
*
grouped_gemm_kernel_args_dev
;
index_t
grid_size_
;
index_t
grid_size_
;
index_t
grid_size_grp
;
index_t
grid_size_grp
_
;
index_t
sum_of_m
;
index_t
sum_of_m
;
index_t
k_batch_
;
};
};
// Invoker
// Invoker
...
@@ -570,37 +600,25 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -570,37 +600,25 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{
{
bool
has_main_k_block_loop
=
true
;
bool
has_main_k_block_loop
=
true
;
std
::
vector
<
GroupedGemmKernelArgument
<
NumDTensor
>>
grouped_gemm_kernel_args
;
grouped_gemm_kernel_args
.
reserve
(
arg
.
gemm_desc_kernel_arg_
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
const
auto
KPad
=
const
auto
KPad
=
GridwiseGemm
::
CalculateKPadded
(
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
k_batch
);
GridwiseGemm
::
CalculateKPadded
(
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
arg
.
k_batch
_
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
KPad
)
!=
has_main_k_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
KPad
)
!=
has_main_k_block_loop
)
{
{
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
}
}
}
grouped_gemm_kernel_args
.
push_back
(
if
(
arg
.
grouped_gemm_kernel_args_dev
==
nullptr
)
GroupedGemmKernelArgument
<
NumDTensor
>
{
arg
.
gemm_desc_kernel_arg_
[
i
].
a_ptr_
,
{
arg
.
gemm_desc_kernel_arg_
[
i
].
b_ptr_
,
throw
std
::
runtime_error
(
"wrong! grouped_gemm_kernel_args_dev is nullpr"
);
arg
.
gemm_desc_kernel_arg_
[
i
].
ds_ptr_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
e_ptr_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
M_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
N_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
K_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideA_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideB_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideDs_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideE_
});
}
}
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
e_global_memory_operation_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_gemm_xdl_fixed_nk
<
GridwiseGemm
,
kernel_grouped_gemm_xdl_fixed_nk
<
GridwiseGemm
,
GroupedGemmKernelArgument
<
NumDTensor
>
,
GroupedGemmKernelArgument
<
NumDTensor
>
,
...
@@ -615,13 +633,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -615,13 +633,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
e_global_memory_operation_
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
>
;
if
(
arg
.
grouped_gemm_kernel_args_dev
==
nullptr
)
{
throw
std
::
runtime_error
(
"wrong! grouped_gemm_kernel_args_dev is nullpr"
);
}
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
kernel
,
...
@@ -630,20 +644,43 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -630,20 +644,43 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
grouped_gemm_kernel_args_dev
),
cast_pointer_to_constant_address_space
(
arg
.
grouped_gemm_kernel_args_dev
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
grid_size_grp
,
arg
.
grid_size_grp
_
,
k_batch
,
arg
.
k_batch
_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
};
};
if
(
has_main_k_block_loop
)
constexpr
auto
AtomicAdd
=
InMemoryDataOperationEnum
::
AtomicAdd
;
constexpr
auto
Set
=
InMemoryDataOperationEnum
::
Set
;
if
(
arg
.
k_batch_
>
1
)
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
AtomicAdd
>
{});
}
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
InMemoryDataOperationEnum
,
Set
>
{});
}
}
}
return
ave_time
;
return
ave_time
;
...
@@ -775,6 +812,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -775,6 +812,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GroupedGemmKernelArgument
<
NumDTensor
>
);
sizeof
(
GroupedGemmKernelArgument
<
NumDTensor
>
);
}
}
/// Applies a new k_batch value to an already constructed Argument.
///
/// Simply forwards to arg.UpdateKBatch(k_batch); any exception thrown there
/// propagates to the caller.
///
/// @param arg     argument object to update in place.
/// @param k_batch new k_batch value.
static void SetKBatch(Argument& arg, index_t k_batch)
{
    arg.UpdateKBatch(k_batch);
}
// polymorphic
/// Type-erased overload: applies k_batch to an argument passed as BaseArgument*.
///
/// @param p_arg   pointer to an argument whose dynamic type must be this op's Argument.
/// @param k_batch new k_batch value, forwarded to SetKBatch(Argument&, index_t).
/// @throws std::runtime_error if p_arg is null or does not point to an Argument.
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
    // dynamic_cast on a pointer yields nullptr on a type mismatch (or null input);
    // dereferencing that unchecked would be undefined behaviour.
    auto* const p_typed_arg = dynamic_cast<Argument*>(p_arg);

    if(p_typed_arg == nullptr)
    {
        throw std::runtime_error("wrong! SetKBatch: p_arg is not a valid Argument");
    }

    SetKBatch(*p_typed_arg, k_batch);
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
View file @
ee1c88b1
...
@@ -37,7 +37,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
...
@@ -37,7 +37,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -632,6 +631,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -632,6 +631,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
__device__
__host__
static
constexpr
auto
GetMPerBlock
()
{
return
MPerBlock
;
}
__device__
__host__
static
constexpr
auto
GetMPerBlock
()
{
return
MPerBlock
;
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
AGridDesc_KBatch_AK0_M_AK1
,
typename
AGridDesc_KBatch_AK0_M_AK1
,
typename
BGridDesc_KBatch_BK0_N_BK1
,
typename
BGridDesc_KBatch_BK0_N_BK1
,
typename
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -1074,6 +1074,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -1074,6 +1074,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -1139,19 +1140,20 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
...
@@ -1139,19 +1140,20 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
);
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
);
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
Run
<
HasMainKBlockLoop
,
EGlobalMemoryDataOperation
>
(
p_b_grid
,
p_a_grid
,
p_ds_grid
,
p_b_grid
,
p_e_grid
,
p_ds_grid
,
p_shared
,
p_e_grid
,
a_element_op
,
p_shared
,
b_element_op
,
a_element_op
,
cde_element_op
,
b_element_op
,
a_grid_desc_kbatch_ak0_m_ak1
,
cde_element_op
,
b_grid_desc_kbatch_bk0_n_bk1
,
a_grid_desc_kbatch_ak0_m_ak1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
b_grid_desc_kbatch_bk0_n_bk1
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment