Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
de6a70f7
Commit
de6a70f7
authored
Jul 19, 2022
by
Jing Zhang
Browse files
add ds
parent
1d11426a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
193 additions
and
158 deletions
+193
-158
example/28_batched_gemm_multi_d/batched_gemm_multi_d_xdl_fp16.cpp
...28_batched_gemm_multi_d/batched_gemm_multi_d_xdl_fp16.cpp
+24
-32
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
...nsor_operation/gpu/device/device_batched_gemm_multi_d.hpp
+14
-11
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp
..._operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp
+155
-115
No files found.
example/28_batched_gemm_multi_d/batched_gemm_multi_d_xdl_fp16.cpp
View file @
de6a70f7
...
@@ -33,9 +33,9 @@ using CShuffleDataType = F16;
...
@@ -33,9 +33,9 @@ using CShuffleDataType = F16;
using
DsDataType
=
ck
::
Tuple
<>
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
...
@@ -63,9 +63,9 @@ int main(int argc, char* argv[])
...
@@ -63,9 +63,9 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
const
int
M
=
256
;
const
int
M
=
256
*
(
rand
()
%
16
+
1
)
;
const
int
N
=
128
;
const
int
N
=
128
*
(
rand
()
%
16
+
1
)
;
const
int
K
=
64
;
const
int
K
=
64
*
(
rand
()
%
16
+
1
)
;
const
int
stride_A
=
K
;
const
int
stride_A
=
K
;
const
int
stride_B
=
K
;
const
int
stride_B
=
K
;
...
@@ -112,12 +112,12 @@ int main(int argc, char* argv[])
...
@@ -112,12 +112,12 @@ int main(int argc, char* argv[])
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
BLayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
BLayout
{}));
Tensor
<
EDataType
>
c
_g_m_n_device_result
(
Tensor
<
EDataType
>
e
_g_m_n_device_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
ELayout
{}));
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
ELayout
{}));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c
_g_m_n: "
<<
c
_g_m_n_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
e
_g_m_n: "
<<
e
_g_m_n_device_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -134,35 +134,38 @@ int main(int argc, char* argv[])
...
@@ -134,35 +134,38 @@ int main(int argc, char* argv[])
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
c
_g_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e
_g_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CDEElementOp
{};
auto
c
de
_element_op
=
CDEElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEMM
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
b_device_buf
.
GetDeviceBuffer
(),
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
{},
c_device_buf
.
GetDeviceBuffer
(),
M
,
M
,
N
,
N
,
K
,
K
,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
{},
stride_C
,
stride_C
,
batch_stride_A
,
batch_stride_A
,
batch_stride_B
,
batch_stride_B
,
{},
batch_stride_C
,
batch_stride_C
,
batch_count
,
batch_count
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c
de
_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -189,32 +192,21 @@ int main(int argc, char* argv[])
...
@@ -189,32 +192,21 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
c_device_buf
.
FromDevice
(
c
_g_m_n_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
e
_g_m_n_device_result
.
mData
.
data
());
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
Tensor
<
EDataType
>
c
_g_m_n_host_result
=
HostTensorDescriptor
(
Tensor
<
EDataType
>
e
_g_m_n_host_result
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
M
,
N
})
,
st
d
::
vector
<
std
::
size_t
>
({
M
*
N
,
N
,
1
}));
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
st
ride_C
,
ELayout
{
}));
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_g_m_k
,
b_g_k_n
,
c
_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_g_m_k
,
b_g_k_n
,
e
_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c
de
_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
//for(int b = 0; b < batch_count; b++)
//{
//for(int m = 0; m < M; m++)
//{
//for(int n = 0; n < N; n++)
//{
//c_g_m_n_host_result(b, m, n) = c_g_m_n_host_result(b, m, n);
//}
//}
//}
pass
=
ck
::
utils
::
check_err
(
pass
=
ck
::
utils
::
check_err
(
c
_g_m_n_host_result
.
mData
,
c
_g_m_n_device_result
.
mData
,
"Error: Incorrect results c"
);
e
_g_m_n_host_result
.
mData
,
e
_g_m_n_device_result
.
mData
,
"Error: Incorrect results c"
);
}
}
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
View file @
de6a70f7
...
@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator
...
@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideB
,
//
std::array<ck::index_t, NumDTensor> StrideDs,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
ck
::
index_t
StrideE
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB
,
ck
::
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
ck
::
index_t
BatchStrideE
,
ck
::
index_t
BatchStrideE
,
ck
::
index_t
Batch
,
ck
::
index_t
Batch
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
@@ -58,16 +60,17 @@ template <typename ALayout,
...
@@ -58,16 +60,17 @@ template <typename ALayout,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
>
using
DeviceBatchedGemmMultiDPtr
=
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
ALayout
,
using
DeviceBatchedGemmMultiDPtr
=
BLayout
,
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
ALayout
,
CLayout
,
BLayout
,
ADataType
,
CLayout
,
BDataType
,
ADataType
,
DsDataType
,
BDataType
,
EDataType
,
DsDataType
,
AElementwiseOperation
,
EDataType
,
BElementwiseOperation
,
AElementwiseOperation
,
CDEElementwiseOperation
>>
;
BElementwiseOperation
,
CDEElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp
View file @
de6a70f7
...
@@ -47,10 +47,12 @@ namespace device {
...
@@ -47,10 +47,12 @@ namespace device {
*/
*/
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatDsPointer
,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
@@ -63,18 +65,22 @@ __global__ void
...
@@ -63,18 +65,22 @@ __global__ void
#endif
#endif
kernel_batched_gemm_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_batched_gemm_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatDsPointer
p_ds_grid
,
FloatC
*
__restrict__
p_e_grid
,
FloatC
*
__restrict__
p_e_grid
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
c_element_op
,
const
CDEElementwiseOperation
c
de
_element_op
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -84,38 +90,47 @@ __global__ void
...
@@ -84,38 +90,47 @@ __global__ void
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
FloatDsPointer
p_ds_grid_grp
;
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
static
constexpr
index_t
NumDTensor
=
ck
::
Tuple
<>
{},
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
p_e_grid
+
c_batch_offset
,
p_shared
,
static_for
<
0
,
NumDTensor
,
1
>
{}(
a_element_op
,
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
b_element_op
,
c_element_op
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
a_grid_desc_k0_m_k1
,
p_b_grid
+
b_batch_offset
,
b_grid_desc_k0_n_k1
,
p_ds_grid_grp
,
ck
::
StaticallyIndexedArray
<
p_e_grid
+
e_batch_offset
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
p_shared
,
0
>
{},
a_element_op
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
b_element_op
,
block_2_ctile_map
);
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c
de
_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
#endif
#endif
...
@@ -456,8 +471,12 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -456,8 +471,12 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
{
{
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideE_
(
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
{
}
}
...
@@ -471,7 +490,16 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -471,7 +490,16 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
}
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_offset
[
i
]
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
}
...
@@ -479,6 +507,7 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -479,6 +507,7 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
private:
private:
index_t
BatchStrideA_
;
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
BatchStrideE_
;
};
};
...
@@ -535,41 +564,46 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -535,41 +564,46 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
ADataType
*
p_a_grid
,
Argument
(
const
void
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
void
*
p_b_grid
,
EDataType
*
p_e_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
index_t
StrideE
,
index_t
BatchStrideA
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
,
index_t
BatchStrideE
,
index_t
Batch
,
index_t
Batch
,
index_t
M01
,
index_t
M01
,
index_t
N01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
c_element_op
)
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_e_grid_
{
p_e_grid
},
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
Batch_
(
Batch
),
Batch_
(
Batch
),
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
DeviceBatchedGemmMultiDXdl
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
K
,
StrideA
)},
DeviceBatchedGemmMultiDXdl
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
K
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
DeviceBatchedGemmMultiDXdl
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
N
,
StrideB
)},
DeviceBatchedGemmMultiDXdl
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
N
,
StrideB
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceBatchedGemmMultiDXdl
::
MakeEGridDescriptor_M_N
(
M
,
N
,
StrideE
)},
e_grid_desc_m_n_
{
DeviceBatchedGemmMultiDXdl
::
MakeEGridDescriptor_M_N
(
M
,
N
,
StrideE
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
compute_ptr_offset_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideE
},
compute_ptr_offset_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideDs
,
BatchStrideE
},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
c_element_op
}
cde_element_op_
{
c
de
_element_op
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -579,6 +613,19 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -579,6 +613,19 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
M
,
N
,
StrideDs
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
});
}
}
}
}
...
@@ -646,77 +693,57 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -646,77 +693,57 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
const
auto
K
=
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_batched_gemm_xdl
<
const
auto
kernel
=
kernel_batched_gemm_xdl
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
EDataType
,
remove_reference_t
<
DeviceBatchedGemmMultiDXdl
::
AGridDesc_AK0_M_AK1
>
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
remove_reference_t
<
DeviceBatchedGemmMultiDXdl
::
BGridDesc_BK0_N_BK1
>
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
has_main_loop
>
;
ave_time
=
return
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
Batch_
,
arg
.
Batch_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
b_element_op_
,
arg
.
a_element_op_
,
arg
.
cde_element_op_
,
arg
.
b_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
};
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
const
auto
kernel
=
kernel_batched_gemm_xdl
<
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
EDataType
,
remove_reference_t
<
DeviceBatchedGemmMultiDXdl
::
AGridDesc_AK0_M_AK1
>
,
remove_reference_t
<
DeviceBatchedGemmMultiDXdl
::
BGridDesc_BK0_N_BK1
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_e_grid_
,
arg
.
Batch_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -750,81 +777,94 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -750,81 +777,94 @@ struct DeviceBatchedGemmMultiDXdl : public DeviceBatchedGemmMultiD<ALayout,
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
const
void
*
p_a
,
const
BDataType
*
p_b
,
const
void
*
p_b
,
EDataType
*
p_c
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
index_t
StrideE
,
index_t
BatchStrideA
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
,
index_t
BatchStrideE
,
index_t
Batch
,
index_t
Batch
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
c_element_op
)
CDEElementwiseOperation
c
de
_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
p_ds
,
p_c
,
p_c
,
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideDs
,
StrideE
,
StrideE
,
BatchStrideA
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB
,
BatchStrideDs
,
BatchStrideE
,
BatchStrideE
,
Batch
,
Batch
,
1
,
1
,
1
,
1
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
};
c
de
_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
const
void
*
p_b
,
MakeArgumentPointer
(
const
void
*
p_a
,
void
*
p_c
,
const
void
*
p_b
,
index_t
M
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
index_t
N
,
void
*
p_c
,
index_t
K
,
index_t
M
,
index_t
StrideA
,
index_t
N
,
index_t
StrideB
,
index_t
K
,
index_t
StrideE
,
index_t
StrideA
,
index_t
BatchStrideA
,
index_t
StrideB
,
index_t
BatchStrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
BatchStrideE
,
index_t
StrideE
,
index_t
Batch
,
index_t
BatchStrideA
,
AElementwiseOperation
a_element_op
,
index_t
BatchStrideB
,
BElementwiseOperation
b_element_op
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
CDEElementwiseOperation
c_element_op
)
override
index_t
BatchStrideE
,
index_t
Batch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
p_a
,
static_cast
<
const
BDataType
*>
(
p_b
),
p_b
,
static_cast
<
EDataType
*>
(
p_c
),
p_ds
,
p_c
,
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideDs
,
StrideE
,
StrideE
,
BatchStrideA
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB
,
BatchStrideDs
,
BatchStrideE
,
BatchStrideE
,
Batch
,
Batch
,
1
,
1
,
1
,
1
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c
de
_element_op
);
}
}
// polymorphic
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment