Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0ea428c3
Commit
0ea428c3
authored
Mar 19, 2024
by
Jakub Piasecki
Browse files
tmp save
parent
3e4d0ff3
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
637 additions
and
67 deletions
+637
-67
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+1
-28
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+70
-28
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
...tion/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
+136
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+150
-11
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
...erence_tensor_operation/cpu/reference_gemm_multiple_d.hpp
+170
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+14
-0
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
...tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
..._d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
+95
-0
profiler/src/profile_gemm_fixed_nk.cpp
profiler/src/profile_gemm_fixed_nk.cpp
+0
-0
No files found.
example/15_grouped_gemm/CMakeLists.txt
View file @
0ea428c3
add_custom_target
(
example_grouped_gemm_xdl
)
add_custom_target
(
example_grouped_gemm_xdl
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_
splitk_x
dl_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
endif
()
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
View file @
0ea428c3
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm
_multiple_d
.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -36,24 +36,28 @@ using Row = ck::tensor_layout::gemm::RowMajor;
...
@@ -36,24 +36,28 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddAdd
=
ck
::
tensor_operation
::
element_wise
::
AddAdd
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
DDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
DDataType
,
DDataType
>
;
using
EDataType
=
F32
;
using
EDataType
=
F32
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Col
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
DLayout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
DLayout
,
DLayout
>
;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
using
CDEElementOp
=
AddAdd
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
int
NumDMatrices
=
2
;
// using DeviceGemmInstance =
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
...
@@ -75,6 +79,7 @@ struct ProblemSize final
...
@@ -75,6 +79,7 @@ struct ProblemSize final
std
::
vector
<
ck
::
index_t
>
stride_As
;
std
::
vector
<
ck
::
index_t
>
stride_As
;
std
::
vector
<
ck
::
index_t
>
stride_Bs
;
std
::
vector
<
ck
::
index_t
>
stride_Bs
;
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
stride_Ds
;
std
::
vector
<
ck
::
index_t
>
stride_Cs
;
std
::
vector
<
ck
::
index_t
>
stride_Cs
;
ck
::
index_t
group_count
;
ck
::
index_t
group_count
;
...
@@ -85,7 +90,7 @@ struct ExecutionConfig final
...
@@ -85,7 +90,7 @@ struct ExecutionConfig final
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
int
k_batch
=
128
;
int
k_batch
=
128
;
bool
time_kernel
=
fals
e
;
bool
time_kernel
=
tru
e
;
};
};
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
bool
run_grouped_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
...
@@ -97,10 +102,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -97,10 +102,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std
::
vector
<
void
*>
p_Cs
;
std
::
vector
<
void
*>
p_Cs
;
std
::
vector
<
const
void
*>
p_As
;
std
::
vector
<
const
void
*>
p_As
;
std
::
vector
<
const
void
*>
p_Bs
;
std
::
vector
<
const
void
*>
p_Bs
;
std
::
vector
<
std
::
array
<
const
void
*
,
NumDMatrices
>>
p_Ds
=
{};
gemm_descs
.
reserve
(
group_count
);
gemm_descs
.
reserve
(
group_count
);
p_As
.
reserve
(
group_count
);
p_As
.
reserve
(
group_count
);
p_Bs
.
reserve
(
group_count
);
p_Bs
.
reserve
(
group_count
);
p_Ds
.
reserve
(
group_count
);
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
...
@@ -118,20 +125,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -118,20 +125,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
std
::
array
<
Tensor
<
DDataType
>
,
NumDMatrices
>>
d_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_device_result_tensors
;
std
::
vector
<
Tensor
<
EDataType
>>
c_device_result_tensors
;
a_tensors
.
reserve
(
group_count
);
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
d_tensors
.
reserve
(
group_count
);
c_host_tensors
.
reserve
(
group_count
);
c_host_tensors
.
reserve
(
group_count
);
c_device_result_tensors
.
reserve
(
group_count
);
c_device_result_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
c_tensors_device
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
c_tensors_device
;
std
::
vector
<
std
::
vector
<
DeviceMemPtr
>>
d_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
d_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
...
@@ -142,6 +153,14 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -142,6 +153,14 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
problem_size
.
Ms
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
ALayout
{})));
problem_size
.
Ms
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ks
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Bs
[
i
],
BLayout
{})));
problem_size
.
Ks
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Bs
[
i
],
BLayout
{})));
auto
d0_tensor
=
Tensor
<
DDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
DLayout
{}));
auto
d1_tensor
=
Tensor
<
DDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
DLayout
{}));
std
::
array
<
Tensor
<
DDataType
>
,
NumDMatrices
>
d_tens
=
{
d0_tensor
,
d1_tensor
};
d_tensors
.
push_back
(
d_tens
);
c_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
c_host_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
stride_Cs
[
i
],
ELayout
{})));
c_device_result_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
c_device_result_tensors
.
push_back
(
Tensor
<
EDataType
>
(
f_host_tensor_descriptor
(
...
@@ -153,6 +172,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -153,6 +172,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
flop
+=
std
::
size_t
(
2
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ks
[
i
]
*
problem_size
.
Ns
[
i
];
flop
+=
std
::
size_t
(
2
)
*
problem_size
.
Ms
[
i
]
*
problem_size
.
Ks
[
i
]
*
problem_size
.
Ns
[
i
];
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
GetElementSize
()
+
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
GetElementSize
()
+
sizeof
(
DDataType
)
*
d_tensors
[
i
][
0
].
GetElementSize
()
*
NumDMatrices
+
sizeof
(
EDataType
)
*
c_device_result_tensors
[
i
].
GetElementSize
();
sizeof
(
EDataType
)
*
c_device_result_tensors
[
i
].
GetElementSize
();
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
...
@@ -161,14 +181,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -161,14 +181,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
case
1
:
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
5
,
5
});
}
break
;
break
;
case
2
:
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
}
break
;
break
;
default:
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
}
}
}
}
}
...
@@ -183,44 +212,46 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -183,44 +212,46 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_device_result_tensors
[
i
].
GetElementSpaceSize
()
*
sizeof
(
EDataType
)));
c_device_result_tensors
[
i
].
GetElementSpaceSize
()
*
sizeof
(
EDataType
)));
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors_device
[
i
].
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
d_tensors
[
i
][
j
].
GetElementSpaceSize
()
*
sizeof
(
DDataType
)));
}
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors_device
[
i
][
j
]
->
ToDevice
(
d_tensors
[
i
][
j
].
mData
.
data
());
}
c_tensors_device
[
i
]
->
SetZero
();
c_tensors_device
[
i
]
->
SetZero
();
p_As
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_As
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Bs
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Bs
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Ds
.
push_back
({
d_tensors_device
[
i
][
0
]
->
GetDeviceBuffer
(),
d_tensors_device
[
i
][
1
]
->
GetDeviceBuffer
()});
p_Cs
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_Cs
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
gemm_descs
.
push_back
({
problem_size
.
Ms
[
i
],
gemm_descs
.
push_back
({
problem_size
.
Ms
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
Ns
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
Ks
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_As
[
i
],
problem_size
.
stride_Bs
[
i
],
problem_size
.
stride_Bs
[
i
],
problem_size
.
stride_Cs
[
i
],
problem_size
.
stride_Cs
[
i
],
{}
});
problem_size
.
stride_Ds
[
i
]
});
}
}
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CDEElementOp
{};
auto
c
de
_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_Ds
=
{};
//std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c
de
_element_op
);
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
DeviceMem
gemm_workspace_dev
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
gemm_workspace_dev
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
...
@@ -242,16 +273,17 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -242,16 +273,17 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
MultipleD
<
ADataType
,
BDataType
,
BDataType
,
DsDataType
,
EDataType
,
EDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CDEElementOp
>
;
CDEElementOp
>
;
float
*
p_workspace_dev
=
reinterpret_cast
<
float
*>
(
gemm_workspace_dev
.
GetDeviceBuffer
());
//
float* p_workspace_dev = reinterpret_cast<float*>(gemm_workspace_dev.GetDeviceBuffer());
std
::
size_t
gemm_offset
{
0
};
//
std::size_t gemm_offset{0};
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
{
...
@@ -266,29 +298,30 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -266,29 +298,30 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
// "\n"
// "\n"
// << std::endl;
// << std::endl;
hip_check_error
(
hipMemcpy
(
dev_res_tensor
.
data
(),
//
hip_check_error(hipMemcpy(dev_res_tensor.data(),
p_workspace_dev
+
gemm_offset
,
//
p_workspace_dev + gemm_offset,
dev_res_tensor
.
GetElementSpaceSizeInBytes
(),
//
dev_res_tensor.GetElementSpaceSizeInBytes(),
hipMemcpyDeviceToHost
));
//
hipMemcpyDeviceToHost));
// hip_check_error(hipDeviceSynchronize());
// hip_check_error(hipDeviceSynchronize());
//
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
c_tensors_device
[
i
]
->
FromDevice
(
c_device_result_tensors
[
i
].
mData
.
data
(),
//
c_device_result_tensors[i].mDesc.GetElementSize() *
c_device_result_tensors
[
i
].
mDesc
.
GetElementSize
()
*
//
sizeof(EDataType));
sizeof
(
EDataType
));
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
b_tensors
[
i
],
d_tensors
[
i
],
c_host_tensors
[
i
],
c_host_tensors
[
i
],
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c
de
_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
// pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
// pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
pass
&=
ck
::
utils
::
check_err
(
dev_res_tensor
,
c_host_tensors
[
i
]);
pass
&=
ck
::
utils
::
check_err
(
c_
dev
ice
_res
ult
_tensor
s
[
i
]
,
c_host_tensors
[
i
]);
gemm_offset
+=
argument
.
GetWorkspaceSize
(
i
);
//
gemm_offset += argument.GetWorkspaceSize(i);
}
}
std
::
cout
<<
"Verification: "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
"!"
<<
std
::
endl
;
std
::
cout
<<
"Verification: "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
"!"
<<
std
::
endl
;
...
@@ -332,6 +365,11 @@ int main(int argc, char* argv[])
...
@@ -332,6 +365,11 @@ int main(int argc, char* argv[])
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
problem_size
.
stride_Ds
.
push_back
({});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
problem_size
.
stride_Ds
[
i
].
push_back
(
problem_size
.
Ns
[
i
]);
}
}
}
std
::
cout
std
::
cout
...
@@ -359,6 +397,10 @@ int main(int argc, char* argv[])
...
@@ -359,6 +397,10 @@ int main(int argc, char* argv[])
problem_size
.
stride_Bs
=
argToIntArray
(
argv
[
8
]);
problem_size
.
stride_Bs
=
argToIntArray
(
argv
[
8
]);
problem_size
.
stride_Cs
=
argToIntArray
(
argv
[
9
]);
problem_size
.
stride_Cs
=
argToIntArray
(
argv
[
9
]);
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
problem_size
.
stride_Ds
.
push_back
(
problem_size
.
stride_Cs
);
}
problem_size
.
group_count
=
problem_size
.
Ms
.
size
();
problem_size
.
group_count
=
problem_size
.
Ms
.
size
();
}
}
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
0 → 100644
View file @
0ea428c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmMultipleDKernelArguments
{
__host__
__device__
GroupedGemmMultipleDKernelArguments
(
const
void
*
p_a_grid_
,
const
void
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
void
*
p_e_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideE_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{
p_ds_grid_
},
p_e_grid
{
p_e_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideE
{
StrideE_
}
{
}
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
void
Print
()
const
{
std
::
stringstream
str
;
for
(
auto
sd
:
StrideDs
)
str
<<
sd
<<
","
;
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SE:"
<<
StrideE
<<
", "
<<
"SDs: {"
<<
str
.
str
()
<<
"}"
<<
"}"
<<
std
::
endl
;
}
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedGemmMultipleDSplitK
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
0ea428c3
...
@@ -14,15 +14,19 @@
...
@@ -14,15 +14,19 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include <ck/utility/loop_scheduler.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp>
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -100,7 +104,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -100,7 +104,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
static
constexpr
index_t
K0PerBlock
=
KPerBlock
/
AK1
;
static
constexpr
index_t
K0PerBlock
=
KPerBlock
/
AK1
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Workspace
E
DataType
=
float
;
using
WorkspaceDataType
=
float
;
// First stage GridwiseGEMM kernel.
// First stage GridwiseGEMM kernel.
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
...
@@ -108,7 +112,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -108,7 +112,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
ADataType
,
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
Workspace
E
DataType
,
WorkspaceDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
ELayout
,
ELayout
,
...
@@ -149,7 +153,95 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -149,7 +153,95 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
PipelineVer
,
PipelineVer
,
ComputeDataType
>
;
ComputeDataType
>
;
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideE
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELay
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideE
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELay
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideE
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NumDTensor
>&
MRaws
,
const
std
::
array
<
index_t
,
NumDTensor
>&
NRaws
,
const
std
::
array
<
index_t
,
NumDTensor
>&
DsStride
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
MakeEGridDescriptor_M_N
<
DLayout
>
(
MRaws
[
i
],
NRaws
[
i
],
DsStride
[
i
]);
},
Number
<
NumDTensor
>
{});
}
static
constexpr
auto
MakeDsGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
static_cast
<
const
DDataType
*>
(
nullptr
);
},
Number
<
NumDTensor
>
{});
}
static
constexpr
auto
MakeElementwiseInputSequence
()
{
return
generate_sequence_v2
(
[
&
](
auto
i
)
constexpr
{
return
Number
<
i
+
1
-
i
>
{};
},
Number
<
NumDTensor
+
1
>
{});
//CShuffleNXdlPerWavePerShuffle
}
using
CGridDesc_M_N
=
typename
GridwiseGemm
::
CGridDesc_M_N
;
using
CGridDesc_M_N
=
typename
GridwiseGemm
::
CGridDesc_M_N
;
using
EGridDesc_M_N
=
typename
GridwiseGemm
::
CGridDesc_M_N
;
using
DsGridDesc_M_N
=
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}));
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
using
CDGridDesc_M_N
=
decltype
(
concat_tuple
(
ck
::
Tuple
<
CGridDesc_M_N
>
{},
DsGridDesc_M_N
{}));
//using CDDataTypes = decltype(concat_tuple(ck::Tuple<WorkspaceDataType*>{}, DsDataType{}));
using
CDDataTypes
=
decltype
(
concat_tuple
(
ck
::
Tuple
<
WorkspaceDataType
*>
{},
DsGridPointer
{}));
using
ElementwiseInputSequence
=
decltype
(
MakeElementwiseInputSequence
());
using
GridwiseElementwise
=
GridwiseElementwise_2D
<
CDGridDesc_M_N
,
// zmien na C, D_0, ..., D_n / tuple<C, D_0, ..., D_N>
ck
::
Tuple
<
EGridDesc_M_N
>
,
CDDataTypes
,
// zmien na C, D_0, ..., D_n / tuple<C, D_0, ..., D_N>
ck
::
Tuple
<
EDataType
*>
,
CDEElementwiseOperation
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// MPerThread
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// NPerThread
ElementwiseInputSequence
,
ck
::
Sequence
<
8
>>
;
using
Block2ETileMapKSplit
=
using
Block2ETileMapKSplit
=
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
;
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
;
// Block2CTileMap configuration parameter.
// Block2CTileMap configuration parameter.
...
@@ -231,6 +323,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -231,6 +323,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
}
}
gemm_kernel_args_
.
reserve
(
group_count_
);
gemm_kernel_args_
.
reserve
(
group_count_
);
elementwise_c_grid_descs_m_n_
.
reserve
(
group_count_
);
elementwise_d_grid_descs_m_n_
.
reserve
(
group_count_
);
ds_grid_pointer_
.
reserve
(
group_count_
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
{
{
...
@@ -255,6 +352,17 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -255,6 +352,17 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_e
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_e
);
DsGridDesc_M_N
ds_grid_desc_m_n
;
DsGridPointer
p_ds_grid
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
j
.
value
,
DsDataType
>>
;
p_ds_grid
(
j
)
=
static_cast
<
const
DDataType
*>
(
p_Ds
[
i
][
j
]);
ds_grid_desc_m_n
(
j
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
M
,
N
,
gemm_descs
[
i
].
stride_Ds_
[
j
]);
});
const
auto
local_b2c_tile_map
=
const
auto
local_b2c_tile_map
=
Block2ETileMapKSplit
{
c_grid_desc_m_n
,
B2E_M01
,
K_BATCH
};
Block2ETileMapKSplit
{
c_grid_desc_m_n
,
B2E_M01
,
K_BATCH
};
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
...
@@ -284,7 +392,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -284,7 +392,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
// when workspace will be set, this will be updated to workspace memory.
// when workspace will be set, this will be updated to workspace memory.
auto
karg
=
GemmKernelArgument
{
type_convert
<
const
ADataType
*>
(
p_As
[
i
]),
auto
karg
=
GemmKernelArgument
{
type_convert
<
const
ADataType
*>
(
p_As
[
i
]),
type_convert
<
const
BDataType
*>
(
p_Bs
[
i
]),
type_convert
<
const
BDataType
*>
(
p_Bs
[
i
]),
type_convert
<
E
DataType
*>
(
p_Es
[
i
]),
type_convert
<
Workspace
DataType
*>
(
p_Es
[
i
]),
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -299,7 +407,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -299,7 +407,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_
.
emplace_back
(
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
std
::
move
(
grouped_block_2_ctile_map
),
block_start
,
block_end
);
std
::
move
(
karg
),
std
::
move
(
grouped_block_2_ctile_map
),
block_start
,
block_end
);
elementwise_c_grid_descs_m_n_
.
push_back
(
c_grid_desc_m_n
);
elementwise_d_grid_descs_m_n_
.
push_back
(
ds_grid_desc_m_n
);
ds_grid_pointer_
.
push_back
(
p_ds_grid
);
}
}
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_
=
p_Es
;
}
}
/**
/**
...
@@ -374,7 +488,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -374,7 +488,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
for
(
const
auto
&
arg
:
gemm_kernel_args_
)
for
(
const
auto
&
arg
:
gemm_kernel_args_
)
{
{
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
size_bytes
+=
tiles
*
MPerBlock
*
NPerBlock
*
sizeof
(
Workspace
E
DataType
);
size_bytes
+=
tiles
*
MPerBlock
*
NPerBlock
*
sizeof
(
WorkspaceDataType
);
}
}
return
size_bytes
;
return
size_bytes
;
}
}
...
@@ -401,6 +515,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -401,6 +515,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds_
;
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds_
;
std
::
vector
<
std
::
array
<
index_t
,
NumDTensor
>>
stride_Ds_
;
std
::
vector
<
std
::
array
<
index_t
,
NumDTensor
>>
stride_Ds_
;
std
::
vector
<
GemmTransKernelArg
>
gemm_kernel_args_
;
std
::
vector
<
GemmTransKernelArg
>
gemm_kernel_args_
;
std
::
vector
<
CGridDesc_M_N
>
elementwise_c_grid_descs_m_n_
;
std
::
vector
<
DsGridDesc_M_N
>
elementwise_d_grid_descs_m_n_
;
std
::
vector
<
DsGridPointer
>
ds_grid_pointer_
;
std
::
vector
<
void
*>
e_ptrs_
;
};
};
// Invoker
// Invoker
...
@@ -591,13 +712,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -591,13 +712,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
PassThrough
>
;
PassThrough
>
;
// TODO
// const auto fuse_kernel = ...
const
auto
elementwise_kernel
=
kernel_elementwise_2d
<
GridwiseElementwise
,
return
LaunchKernel
(
gemm_kernel
,
arg
,
dev_gemm_args
,
dev_gemm_workspace
,
stream_config
);
CDGridDesc_M_N
,
ck
::
Tuple
<
EGridDesc_M_N
>
,
CDDataTypes
,
ck
::
Tuple
<
EDataType
*>
,
CDEElementwiseOperation
>
;
return
LaunchKernel
(
gemm_kernel
,
elementwise_kernel
,
arg
,
dev_gemm_args
,
dev_gemm_workspace
,
stream_config
);
}
}
template
<
typename
KernelFunction
>
template
<
typename
KernelFunction
,
typename
KernelFunction2
>
float
LaunchKernel
(
const
KernelFunction
&
gemm_kernel
,
float
LaunchKernel
(
const
KernelFunction
&
gemm_kernel
,
const
KernelFunction2
&
elementwise_kernel
,
const
Argument
&
arg
,
const
Argument
&
arg
,
const
void
*
dev_gemm_args
,
const
void
*
dev_gemm_args
,
[[
maybe_unused
]]
void
*
dev_gemm_workspace
,
[[
maybe_unused
]]
void
*
dev_gemm_workspace
,
...
@@ -624,7 +751,21 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -624,7 +751,21 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg
.
b_element_op_
,
arg
.
b_element_op_
,
PassThrough
{});
PassThrough
{});
// launch fuse kernel.
// launch elementwise kernels.
for
(
int
i
=
0
;
i
<
arg
.
group_count_
;
++
i
)
{
time
+=
launch_and_time_kernel
(
stream_config
,
elementwise_kernel
,
dim3
(
arg
.
grid_size_
),
// chyba group_grid_size <<< tak zmienic na group_grid_size[i]
dim3
(
BlockSize
),
0
,
concat_tuple
(
make_tuple
(
arg
.
elementwise_c_grid_descs_m_n_
[
i
]),
arg
.
elementwise_d_grid_descs_m_n_
[
i
]),
make_tuple
(
arg
.
elementwise_c_grid_descs_m_n_
[
i
]),
concat_tuple
(
make_tuple
(
arg
.
gemm_kernel_args_
[
i
].
karg_
.
p_c_grid
),
arg
.
ds_grid_pointer_
[
i
]),
type_convert
<
EDataType
*>
(
arg
.
e_ptrs_
[
i
]),
arg
.
cde_element_op_
,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// num_threads_m
CDEShuffleBlockTransferScalarPerVector_NPerBlock
);
// num_threads_n
}
return
time
;
return
time
;
}
}
...
@@ -749,8 +890,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -749,8 +890,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
PipelineVer
<<
", "
<<
LoopSched
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
0 → 100644
View file @
0ea428c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
// assumption: every D matrix has the same layout and the same datatype
template
<
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputeTypeA
=
ADataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceGemmMultipleD
:
public
device
::
BaseOperator
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
0
,
DsDataType
>>
;
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
std
::
array
<
Tensor
<
DDataType
>
,
DsDataType
::
Size
()
>&
ds_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
a_m_k_
{
a_m_k
},
b_k_n_
{
b_k_n
},
ds_m_n_
{
ds_m_n
},
c_m_n_
{
c_m_n
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
BDataType
>&
b_k_n_
;
const
std
::
array
<
Tensor
<
DDataType
>
,
DsDataType
::
Size
()
>&
ds_m_n_
;
Tensor
<
CDataType
>&
c_m_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceGemmMultipleD
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
AccDataType
v_acc
=
0
;
ComputeTypeA
v_a
=
0
;
ComputeTypeB
v_b
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
CDataType
v_c
=
0
;
if
constexpr
(
DsDataType
::
Size
()
==
0
)
{
arg
.
cde_element_op_
(
v_c
,
v_acc
);
}
else
if
constexpr
(
DsDataType
::
Size
()
==
1
)
{
arg
.
cde_element_op_
(
v_c
,
v_acc
,
arg
.
ds_m_n_
[
0
](
m
,
n
));
}
else
if
constexpr
(
DsDataType
::
Size
()
==
2
)
{
arg
.
cde_element_op_
(
v_c
,
v_acc
,
arg
.
ds_m_n_
[
0
](
m
,
n
),
arg
.
ds_m_n_
[
1
](
m
,
n
));
}
arg
.
c_m_n_
(
m
,
n
)
=
v_c
;
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
arg
.
c_m_n_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c_m_n_
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
std
::
array
<
Tensor
<
DDataType
>
,
DsDataType
::
Size
()
>&
ds_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_m_k
,
b_k_n
,
ds_m_n
,
c_m_n
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceGemm"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
View file @
0ea428c3
...
@@ -146,6 +146,19 @@ void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
...
@@ -146,6 +146,19 @@ void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
// void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances)
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
...
@@ -190,6 +203,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -190,6 +203,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
(
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
op_ptrs
);
// add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
is_same_v
<
ELayout
,
Row
>
)
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
View file @
0ea428c3
...
@@ -9,4 +9,5 @@ add_instance_library(device_grouped_gemm_instance
...
@@ -9,4 +9,5 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
#device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
)
)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
0 → 100644
View file @
0ea428c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Instances having AK1!=BK1 are temporarily disabled and will be reenabled in future
// a[m, k] * b[k, n] = e[m, n]
using
device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_irregular_tile_instances
=
std
::
tuple
<
// clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// //DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// //DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// //DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// //DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>
;
void
add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemm
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances
{});
add_device_operation_instances
(
instances
,
device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_irregular_tile_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/src/profile_gemm_fixed_nk.cpp
0 → 100644
View file @
0ea428c3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment