Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6adf3591
Commit
6adf3591
authored
Jul 01, 2022
by
Jing Zhang
Browse files
add batch_stride
parent
fa9a0a5c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
48 additions
and
13 deletions
+48
-13
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+16
-4
profiler/include/profile_batched_gemm_impl.hpp
profiler/include/profile_batched_gemm_impl.hpp
+15
-6
profiler/src/profile_batched_gemm.cpp
profiler/src/profile_batched_gemm.cpp
+14
-3
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
View file @
6adf3591
...
@@ -26,6 +26,9 @@ struct DeviceBatchedGemm : public BaseOperator
...
@@ -26,6 +26,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck
::
index_t
StrideA
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC
,
ck
::
index_t
BatchStrideA
,
ck
::
index_t
BatchStrideB
,
ck
::
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
6adf3591
...
@@ -334,6 +334,9 @@ struct DeviceBatchedGemmXdl
...
@@ -334,6 +334,9 @@ struct DeviceBatchedGemmXdl
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
index_t
M01
,
index_t
M01
,
index_t
N01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
@@ -350,10 +353,7 @@ struct DeviceBatchedGemmXdl
...
@@ -350,10 +353,7 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
)},
DeviceBatchedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceBatchedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
)},
c_grid_desc_m_n_
{
DeviceBatchedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
compute_ptr_offset_of_batch_
{
compute_ptr_offset_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideC
},
type_convert
<
index_t
>
(
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
c_grid_desc_m_n_
.
GetElementSpaceSize
())},
block_2_ctile_map_
{
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
)},
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
)},
M01_
{
M01
},
M01_
{
M01
},
...
@@ -536,6 +536,9 @@ struct DeviceBatchedGemmXdl
...
@@ -536,6 +536,9 @@ struct DeviceBatchedGemmXdl
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
@@ -550,6 +553,9 @@ struct DeviceBatchedGemmXdl
...
@@ -550,6 +553,9 @@ struct DeviceBatchedGemmXdl
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
1
,
1
,
1
,
1
,
a_element_op
,
a_element_op
,
...
@@ -570,6 +576,9 @@ struct DeviceBatchedGemmXdl
...
@@ -570,6 +576,9 @@ struct DeviceBatchedGemmXdl
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
@@ -584,6 +593,9 @@ struct DeviceBatchedGemmXdl
...
@@ -584,6 +593,9 @@ struct DeviceBatchedGemmXdl
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
1
,
1
,
1
,
1
,
a_element_op
,
a_element_op
,
...
...
profiler/include/profile_batched_gemm_impl.hpp
View file @
6adf3591
...
@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int
M
,
int
M
,
int
N
,
int
N
,
int
K
,
int
K
,
int
BatchStrideA
,
int
BatchStrideB
,
int
BatchStrideC
,
int
StrideA
,
int
StrideA
,
int
StrideB
,
int
StrideB
,
int
StrideC
,
int
StrideC
,
...
@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std
::
size_t
row
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
auto
layout
)
{
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
row
*
stride
,
stride
,
1
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
stride
,
1
}));
}
}
else
else
{
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
col
*
stride
,
1
,
stride
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
1
,
stride
}));
}
}
};
};
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_g_m_k
(
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB
,
BLayout
{}));
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
BatchStrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB
,
BatchStrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_g_m_n_host_result
(
Tensor
<
CDataType
>
c_g_m_n_host_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_g_m_n_device_result
(
Tensor
<
CDataType
>
c_g_m_n_device_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
...
@@ -148,6 +154,9 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -148,6 +154,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
...
...
profiler/src/profile_batched_gemm.cpp
View file @
6adf3591
...
@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
const
int
StrideA_
=
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
;
const
int
StrideB_
=
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
;
const
int
StrideC_
=
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
;
const
int
BatchStrideA
=
(
ck
::
is_same_v
<
ALayout
,
Row
>
?
M
:
K
)
*
StrideA_
;
const
int
BatchStrideB
=
(
ck
::
is_same_v
<
BLayout
,
Row
>
?
K
:
N
)
*
StrideB_
;
const
int
BatchStrideC
=
(
ck
::
is_same_v
<
CLayout
,
Row
>
?
M
:
N
)
*
StrideC_
;
bool
pass
=
ck
::
profiler
::
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
do_verification
,
...
@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
...
@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M
,
M
,
N
,
N
,
K
,
K
,
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
,
BatchStrideA
,
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
,
BatchStrideB
,
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
,
BatchStrideC
,
StrideA_
,
StrideB_
,
StrideC_
,
BatchCount
);
BatchCount
);
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment