gaoqiong / composable_kernel

Commit 5f94555b
authored Aug 04, 2022 by Anthony Chang

implement proper interface

parent 98e4c0ce
Showing 4 changed files with 229 additions and 66 deletions
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp                          +72 -28
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp                  +82 -0
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp     +73 -38
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp  +2  -0
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
@@ -49,23 +49,26 @@ using B1Layout = Row;
 using CLayout = Row;
 
 using AElementOp  = PassThrough;
-using BElementOp  = PassThrough;
+using B0ElementOp = PassThrough;
+using B1ElementOp = PassThrough;
 using CElementOp  = PassThrough;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
     ALayout,
     B0Layout,
     B1Layout,
     CLayout,
     ADataType,
     B0DataType,
+    B1DataType,
     CDataType,
     AccDataType,
     CShuffleDataType,
     AElementOp,
-    BElementOp,
+    B0ElementOp,
+    B1ElementOp,
     CElementOp,
     GemmDefault,
     1,

@@ -114,10 +117,10 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                   ADataType,
                                                                                   AccDataType,
                                                                                   AElementOp,
-                                                                                  BElementOp,
+                                                                                  B0ElementOp,
                                                                                   CElementOp>;
 
 using ReferenceGemm1Instance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;
 
 int main(int argc, char* argv[])
 {

@@ -130,7 +133,7 @@ int main(int argc, char* argv[])
     // ck::index_t N = 1024;
     // ck::index_t K = 64;
     // ck::index_t O = 64;
+    // ck::index_t BatchCount = 4;
     // ck::index_t StrideA = 1024;
     // ck::index_t StrideB0 = 1024;
     // ck::index_t StrideB1 = 1024;

@@ -140,11 +143,15 @@ int main(int argc, char* argv[])
     ck::index_t N = 128;
     ck::index_t K = 32;
     ck::index_t O = 128;
-    ck::index_t BatchCount = 4;
     ck::index_t StrideA  = 32;
     ck::index_t StrideB0 = 32;
     ck::index_t StrideB1 = 128;
     ck::index_t StrideC  = 128;
+    ck::index_t BatchCount    = 64;
+    ck::index_t BatchStrideA  = -1;
+    ck::index_t BatchStrideB0 = -1;
+    ck::index_t BatchStrideB1 = -1;
+    ck::index_t BatchStrideC  = -1;
 
     if(argc == 1)
     {

@@ -156,7 +163,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 17)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);

@@ -167,12 +174,17 @@ int main(int argc, char* argv[])
         K = std::stoi(argv[6]);
         O = std::stoi(argv[7]);
-        StrideA    = std::stoi(argv[8]);
-        StrideB0   = std::stoi(argv[9]);
-        StrideB1   = std::stoi(argv[10]);
-        StrideC    = std::stoi(argv[11]);
-        BatchCount = std::stoi(argv[12]);
+        BatchCount = std::stoi(argv[8]);
+        StrideA  = std::stoi(argv[9]);
+        StrideB0 = std::stoi(argv[10]);
+        StrideB1 = std::stoi(argv[11]);
+        StrideC  = std::stoi(argv[12]);
+        BatchStrideA  = std::stoi(argv[13]);
+        BatchStrideB0 = std::stoi(argv[14]);
+        BatchStrideB1 = std::stoi(argv[15]);
+        BatchStrideC  = std::stoi(argv[16]);
     }
     else
     {

@@ -183,29 +195,55 @@ int main(int argc, char* argv[])
         exit(0);
     }
 
+    const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
+    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
+    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
+    const int DefaultStrideC  = ck::is_same_v<CLayout, Row> ? O : M;
+
+    StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
+    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
+    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
+    StrideC  = (StrideC < 0) ? DefaultStrideC : StrideC;
+
+    const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Row> ? K : M) * StrideA;
+    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Row> ? N : K) * StrideB0;
+    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Row> ? O : N) * StrideB1;
+    const int DefaultBatchStrideC  = (ck::is_same_v<CLayout, Row> ? O : M) * StrideC;
+
+    BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
+    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
+    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
+    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
+                                       std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+        if(std::is_same<decltype(layout), Row>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
        }
    };
 
     // C_m_o = A_m_k * B0_k_n * B1_n_o
-    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB0, B0Layout{}));
-    Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(BatchCount, N, O, StrideB1, B1Layout{}));
-    Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
-    Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(BatchCount, M, O, StrideC, CLayout{}));
+    Tensor<ADataType> a_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<B0DataType> b0_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
+    Tensor<B1DataType> b1_n_o(
+        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
+    Tensor<CDataType> c_m_o_host_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
+    Tensor<CDataType> c_m_o_device_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
 
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;

@@ -241,7 +279,8 @@ int main(int argc, char* argv[])
     b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
 
     auto a_element_op  = AElementOp{};
-    auto b_element_op  = BElementOp{};
+    auto b0_element_op = B0ElementOp{};
+    auto b1_element_op = B1ElementOp{};
     auto c_element_op  = CElementOp{};
 
     // do GEMM

@@ -255,14 +294,19 @@ int main(int argc, char* argv[])
                                       N,
                                       K,
                                       O,
+                                      BatchCount,
                                       StrideA,
                                       StrideB0,
                                       StrideB1,
                                       StrideC,
+                                      BatchStrideA,
+                                      BatchStrideB0,
+                                      BatchStrideB1,
+                                      BatchStrideC,
                                       a_element_op,
-                                      b_element_op,
-                                      c_element_op,
-                                      BatchCount);
+                                      b0_element_op,
+                                      b1_element_op,
+                                      c_element_op);
 
     if(!gemm.IsSupportedArgument(argument))
     {

@@ -290,19 +334,19 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
         // Output of Gemm0 is input A of Gemm1
-        Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, Row{}));
+        Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
 
         auto ref_gemm0          = ReferenceGemm0Instance{};
         auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
         auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-            a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);
+            a_m_k, b0_k_n, a1_m_n, a_element_op, b0_element_op, PassThrough{});
 
         ref_gemm0_invoker.Run(ref_gemm0_argument);
 
         auto ref_gemm1          = ReferenceGemm1Instance{};
         auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
         auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-            a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
+            a1_m_n, b1_n_o, c_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
 
         ref_gemm1_invoker.Run(ref_gemm1_argument);
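Taken together, the example's command line grows from 12 to 16 user arguments: BatchCount moves up to argv[8], the four per-matrix strides shift to argv[9..12], and four new batch strides occupy argv[13..16], with any negative batch stride falling back to the packed default computed above. A hedged invocation sketch (the binary name and concrete sizes are illustrative only; M and N are read from argv[4] and argv[5] in parsing not shown in this hunk):

    # do_verification  init_method  time_kernel
    # M N K O  BatchCount
    # StrideA StrideB0 StrideB1 StrideC
    # BatchStrideA BatchStrideB0 BatchStrideB1 BatchStrideC   (-1 = use packed default)
    ./batched_gemm_gemm_xdl_fp16 1 1 1  256 128 32 128  64  32 32 128 128  -1 -1 -1 -1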
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t StrideC,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        ck::index_t BatchStrideC,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
                                                                       CLayout,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,
                                                                       CDataType,
                                                                       AElementwiseOperation,
                                                                       B0ElementwiseOperation,
                                                                       B1ElementwiseOperation,
                                                                       CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
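This header supplies the type-erased base class that the commit title refers to: callers can hold a DeviceBatchedGemmGemmPtr and drive the operation through MakeArgumentPointer / MakeInvokerPointer without naming the concrete XDL instance. A minimal usage sketch, assuming a concrete instance alias DeviceGemmInstance like the one in the example above, valid device buffer pointers, and CK's PassThrough element-ops; the IsSupportedArgument check and the StreamConfig argument to Run follow the convention of CK's other device operations and are assumptions here, not part of this diff:

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto gemm         = DeviceGemmInstance{};
    auto invoker_ptr  = gemm.MakeInvokerPointer();
    auto argument_ptr = gemm.MakeArgumentPointer(
        p_a, p_b0, p_b1, p_c,                                      // device buffers
        M, N, K, O, BatchCount,                                    // problem sizes
        StrideA, StrideB0, StrideB1, StrideC,                      // intra-batch strides
        BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC,  // inter-batch strides
        PassThrough{}, PassThrough{}, PassThrough{}, PassThrough{});

    if(gemm.IsSupportedArgument(argument_ptr.get()))
    {
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
    }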
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
 #include "ck/device_utility/device_prop.hpp"

@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,

@@ -43,6 +44,7 @@ __global__ void
             FloatC* __restrict__ p_c_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
+             const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,

@@ -68,14 +70,6 @@ __global__ void
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
 
-    // if(threadIdx.x == 0)
-    // printf("bid = %zd, offset a b c d = %zd, %zd, %zd, %zd\n",
-    //        hipBlockIdx_x,
-    //        a_batch_offset,
-    //        b_batch_offset,
-    //        b1_batch_offset,
-    //        c_batch_offset);
-
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,

@@ -83,6 +77,7 @@ __global__ void
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
+                                                  b1_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,

@@ -92,20 +87,21 @@ __global__ void
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
+    ignore = p_b1_grid;
     ignore = p_c_grid;
+    ignore = p_shared;
     ignore = a_element_op;
     ignore = b_element_op;
+    ignore = b1_element_op;
     ignore = c_element_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
+    ignore = b1_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
-// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
-// version currently has compiler issues with register spill which further causes validation
-// failures.
 // Computes C = A * B0 * B1
 //                ^^^^^^ (Acc0)
 //           ^^^^^^^^^^^ (Acc1)

@@ -114,12 +110,14 @@ template <typename ALayout,
           typename B1Layout,
           typename CLayout,
           typename ADataType,
-          typename BDataType, // NOTE: don't distinguish B0/B1 type just yet
+          typename BDataType,
+          typename B1DataType,
           typename CDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename AElementwiseOperation,
-          typename BElementwiseOperation, // NOTE: don't distinguish B0/B1 type just yet
+          typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,

@@ -163,9 +161,20 @@ template <typename ALayout,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
+struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
+                                                                         BLayout,
+                                                                         B1Layout,
+                                                                         CLayout,
+                                                                         ADataType,
+                                                                         BDataType,
+                                                                         B1DataType,
+                                                                         CDataType,
+                                                                         AElementwiseOperation,
+                                                                         BElementwiseOperation,
+                                                                         B1ElementwiseOperation,
+                                                                         CElementwiseOperation>
 {
-    using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -526,6 +535,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         CDataType,
         AElementwiseOperation,
         BElementwiseOperation,
+        B1ElementwiseOperation,
         CElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,

@@ -582,20 +592,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
-                 const BDataType* p_b1_grid,
+                 const B1DataType* p_b1_grid,
                  CDataType* p_c_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
                  index_t Gemm1NRaw, // = ORaw
+                 index_t Batch,
                  index_t StrideA,
                  index_t StrideB,
                  index_t StrideB1,
                  index_t StrideC,
+                 index_t BatchStrideA,
+                 index_t BatchStrideB,
+                 index_t BatchStrideB1,
+                 index_t BatchStrideC,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op,
-                 index_t Batch)
+                 B1ElementwiseOperation b1_element_op,
+                 CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},

@@ -609,13 +624,10 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
+              b1_element_op_{b1_element_op},
               c_element_op_{c_element_op},
               batch_count_(Batch),
-              compute_base_ptr_of_batch_{
-                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
-                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
-                  type_convert<index_t>(b1_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
-                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())}
+              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,

@@ -632,7 +644,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        const BDataType* p_b1_grid_;
+        const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;

@@ -643,6 +655,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
+        B1ElementwiseOperation b1_element_op_;
         CElementwiseOperation c_element_op_;
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

@@ -680,6 +693,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
             CDataType,
             AElementwiseOperation,
             BElementwiseOperation,
+            B1ElementwiseOperation,
             CElementwiseOperation,
             DeviceOp::AGridDesc_AK0_M_AK1,
             DeviceOp::BGridDesc_BK0_N_BK1,

@@ -700,6 +714,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                 arg.p_c_grid_,
                 arg.a_element_op_,
                 arg.b_element_op_,
+                arg.b1_element_op_,
                 arg.c_element_op_,
                 arg.a_grid_desc_ak0_m_ak1_,
                 arg.b_grid_desc_bk0_n_bk1_,

@@ -760,20 +775,25 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
-                             const BDataType* p_b1,
+                             const B1DataType* p_b1,
                              CDataType* p_c,
                              index_t MRaw,
                              index_t NRaw,
                              index_t KRaw,
                              index_t Gemm1NRaw,
+                             index_t Batch,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideB1,
                              index_t StrideC,
+                             index_t BatchStrideA,
+                             index_t BatchStrideB,
+                             index_t BatchStrideB1,
+                             index_t BatchStrideC,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             index_t Batch)
+                             B1ElementwiseOperation b1_element_op,
+                             CElementwiseOperation c_element_op)
     {
         return Argument{p_a,
                         p_b,

@@ -783,14 +803,19 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                         NRaw,
                         KRaw,
                         Gemm1NRaw,
+                        Batch,
                         StrideA,
                         StrideB,
                         StrideB1,
                         StrideC,
+                        BatchStrideA,
+                        BatchStrideB,
+                        BatchStrideB1,
+                        BatchStrideC,
                         a_element_op,
                         b_element_op,
-                        c_element_op,
-                        Batch};
+                        b1_element_op,
+                        c_element_op};
     }
 
     static auto MakeInvoker() { return Invoker{}; }

@@ -804,35 +829,45 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                         index_t NRaw,
                         index_t KRaw,
                         index_t Gemm1NRaw,
+                        index_t Batch,
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideB1,
                         index_t StrideC,
+                        index_t BatchStrideA,
+                        index_t BatchStrideB,
+                        index_t BatchStrideB1,
+                        index_t BatchStrideC,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t Batch) /* override */
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
-                                          static_cast<const BDataType*>(p_b1),
+                                          static_cast<const B1DataType*>(p_b1),
                                           static_cast<CDataType*>(p_c),
                                           MRaw,
                                           NRaw,
                                           KRaw,
                                           Gemm1NRaw,
+                                          Batch,
                                           StrideA,
                                           StrideB,
                                           StrideB1,
                                           StrideC,
+                                          BatchStrideA,
+                                          BatchStrideB,
+                                          BatchStrideB1,
+                                          BatchStrideC,
                                           a_element_op,
                                           b_element_op,
-                                          c_element_op,
-                                          Batch);
+                                          b1_element_op,
+                                          c_element_op);
     }
 
     // polymorphic
-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() /* override */
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>(Invoker{});
     }

@@ -843,7 +878,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGemmGemm_Xdl_CShuffle"
+        str << "DeviceBatchedGemmGemm_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "

@@ -851,7 +886,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
             << KPerBlock << ", "
             << AK1 << ", "
             << BK1 << ", "
-            << NPerBlock << ", "
+            << MPerBlock << ", "
             << Gemm1NPerBlock << ", "
             << Gemm1KPerBlock << ", "
             << B1K1 << ">";
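Two behavioral points in this file go beyond the renames. First, the operator now derives from the new DeviceBatchedGemmGemm base, so MakeArgumentPointer and MakeInvokerPointer become real overrides rather than commented-out ones. Second, ComputeBasePtrOfStridedBatch is now seeded with the caller's BatchStrideA/B/B1/C instead of each descriptor's element-space size, so batches no longer have to be densely packed. A sketch of the resulting offset arithmetic, assuming the helper simply scales the batch index (its definition is not part of this diff, so the member names below are illustrative):

    // Hedged sketch: per-batch base offsets as consumed by the kernel above,
    // e.g. compute_base_ptr_of_batch.GetCBasePtr(g_idx).
    struct ComputeBasePtrOfStridedBatchSketch
    {
        ck::index_t batch_stride_a_, batch_stride_b_, batch_stride_b1_, batch_stride_c_;

        __host__ __device__ ck::long_index_t GetABasePtr(ck::index_t g_idx) const
        {
            // batch g starts g_idx * BatchStrideA elements into the A buffer
            return static_cast<ck::long_index_t>(g_idx) * batch_stride_a_;
        }
        // the B, B1, and C accessors follow the same pattern
    };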
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
@@ -23,6 +23,7 @@ template <typename FloatAB,
           typename FloatC,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
+          typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,

@@ -316,6 +317,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                 void* __restrict__ p_shared,
                 const AElementwiseOperation& a_element_op,
                 const BElementwiseOperation& b_element_op,
+                 const B1ElementwiseOperation& b1_element_op,
                 const CElementwiseOperation& c_element_op,
                 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
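The gridwise kernel change is minimal: B1ElementwiseOperation becomes its own template parameter and Run() argument, so the second GEMM's B operand can be transformed independently of the first's. For orientation, a CK elementwise operation is just a device-callable functor; a PassThrough-style sketch (CK's real PassThrough lives in ck::tensor_operation::element_wise; this stand-alone version is illustrative only):

    struct PassThroughSketch
    {
        template <typename T>
        __host__ __device__ void operator()(T& y, const T& x) const
        {
            y = x; // identity: hand the loaded value through unchanged
        }
    };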