gaoqiong / composable_kernel · Commits

Commit de6a70f7 authored Jul 19, 2022 by Jing Zhang

    add ds

parent 1d11426a
Showing 3 changed files with 193 additions and 158 deletions (+193 −158).
example/28_batched_gemm_multi_d/batched_gemm_multi_d_xdl_fp16.cpp            +24 −32
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp       +14 −11
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp   +155 −115
example/28_batched_gemm_multi_d/batched_gemm_multi_d_xdl_fp16.cpp  (view file @ de6a70f7)
@@ -33,9 +33,9 @@ using CShuffleDataType = F16;
 using DsDataType = ck::Tuple<>;
 using EDataType  = F16;

-using ALayout = Row;
-using BLayout = Col;
-using ELayout = Row;
+using ALayout = Row;
+using BLayout = Col;
+using ELayout = Row;

 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
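Note (not part of the commit): `DsDataType = ck::Tuple<>` means this example runs with zero extra D tensors, so the `{}` Ds arguments added further down in the diff stay empty. A minimal sketch of how a single extra D tensor (e.g. a row-major bias) would plausibly be declared follows; D0DataType, D0Layout, and DsLayout are illustrative names, not taken from this diff:

    // Hypothetical one-D-tensor configuration (illustrative only):
    using D0DataType = F16;                   // data type of the extra source tensor
    using DsDataType = ck::Tuple<D0DataType>; // NumDTensor becomes 1
    using D0Layout   = Row;                   // assumed layout of D0
    using DsLayout   = ck::Tuple<D0Layout>;   // layout tuple, if the device op takes one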
@@ -63,9 +63,9 @@ int main(int argc, char* argv[])
     int init_method  = 1;
     bool time_kernel = false;

-    const int M = 256;
-    const int N = 128;
-    const int K = 64;
+    const int M = 256 * (rand() % 16 + 1);
+    const int N = 128 * (rand() % 16 + 1);
+    const int K = 64 * (rand() % 16 + 1);

     const int stride_A = K;
     const int stride_B = K;
@@ -112,12 +112,12 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
     Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
-    Tensor<EDataType> c_g_m_n_device_result(
+    Tensor<EDataType> e_g_m_n_device_result(
         f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-    std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
+    std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;

     switch(init_method)
     {
@@ -134,35 +134,38 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(EDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());

-    auto a_element_op = AElementOp{};
-    auto b_element_op = BElementOp{};
-    auto c_element_op = CDEElementOp{};
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

     // do GEMM
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                      static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                      static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                      b_device_buf.GetDeviceBuffer(),
+                                      {},
+                                      c_device_buf.GetDeviceBuffer(),
                                       M,
                                       N,
                                       K,
                                       stride_A,
                                       stride_B,
+                                      {},
                                       stride_C,
                                       batch_stride_A,
                                       batch_stride_B,
+                                      {},
                                       batch_stride_C,
                                       batch_count,
                                       a_element_op,
                                       b_element_op,
-                                      c_element_op);
+                                      cde_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -189,32 +192,21 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
+        c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());

         auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker      = ref_batched_gemm.MakeInvoker();

-        Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
-            std::vector<std::size_t>({batch_count, M, N}),
-            std::vector<std::size_t>({M * N, N, 1}));
+        Tensor<EDataType> e_g_m_n_host_result(
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));

         auto ref_argument = ref_batched_gemm.MakeArgument(
-            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);

         ref_invoker.Run(ref_argument);

-        //for(int b = 0; b < batch_count; b++)
-        //{
-        //for(int m = 0; m < M; m++)
-        //{
-        //for(int n = 0; n < N; n++)
-        //{
-        //c_g_m_n_host_result(b, m, n) = c_g_m_n_host_result(b, m, n);
-        //}
-        //}
-        //}

         pass = ck::utils::check_err(
-            c_g_m_n_host_result.mData, c_g_m_n_device_result.mData, "Error: Incorrect results c");
+            e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
     }

     return pass ? 0 : 1;
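Usage note (not part of the commit): the three `{}` arguments introduced in the MakeArgument call above correspond to the per-D pointers, strides, and batch strides; they are empty here because DsDataType is ck::Tuple<>. With one D tensor they would plausibly be filled as sketched below, where d0_device_buf, stride_D0, and batch_stride_D0 are hypothetical names:

    // Sketch only, assuming NumDTensor == 1 (d0_device_buf, stride_D0, batch_stride_D0 are invented names):
    auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                      b_device_buf.GetDeviceBuffer(),
                                      {d0_device_buf.GetDeviceBuffer()}, // p_ds
                                      c_device_buf.GetDeviceBuffer(),
                                      M,
                                      N,
                                      K,
                                      stride_A,
                                      stride_B,
                                      {stride_D0},                       // StrideDs
                                      stride_C,
                                      batch_stride_A,
                                      batch_stride_B,
                                      {batch_stride_D0},                 // BatchStrideDs
                                      batch_stride_C,
                                      batch_count,
                                      a_element_op,
                                      b_element_op,
                                      cde_element_op);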
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp  (view file @ de6a70f7)
@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
                         void* p_c,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        // std::array<ck::index_t, NumDTensor> StrideDs,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
                         ck::index_t StrideE,
                         ck::index_t BatchStrideA,
                         ck::index_t BatchStrideB,
+                        std::array<ck::index_t, NumDTensor> BatchStrideDs,
                         ck::index_t BatchStrideE,
                         ck::index_t Batch,
                         AElementwiseOperation a_element_op,
@@ -58,16 +60,17 @@ template <typename ALayout,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation>
-using DeviceBatchedGemmMultiDPtr = std::unique_ptr<DeviceBatchedGemmMultiD<ALayout,
-                                                                           BLayout,
-                                                                           CLayout,
-                                                                           ADataType,
-                                                                           BDataType,
-                                                                           DsDataType,
-                                                                           EDataType,
-                                                                           AElementwiseOperation,
-                                                                           BElementwiseOperation,
-                                                                           CDEElementwiseOperation>>;
+using DeviceBatchedGemmMultiDPtr =
+    std::unique_ptr<DeviceBatchedGemmMultiD<ALayout,
+                                            BLayout,
+                                            CLayout,
+                                            ADataType,
+                                            BDataType,
+                                            DsDataType,
+                                            EDataType,
+                                            AElementwiseOperation,
+                                            BElementwiseOperation,
+                                            CDEElementwiseOperation>>;
 } // namespace device
 } // namespace tensor_operation
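Usage note (not part of the commit): through the type-erased interface above, the new Ds parameters are passed as std::array values sized by NumDTensor. The sketch below assumes NumDTensor == 1 and an instance named gemm_ptr; all value names are placeholders, and the trailing elementwise-operation arguments are assumed to follow the same a/b/cde pattern used by the example file:

    // Sketch only: calling MakeArgumentPointer with one D tensor (all names are placeholders).
    std::array<const void*, 1> p_ds            = {p_d0};
    std::array<ck::index_t, 1> stride_ds       = {StrideD0};
    std::array<ck::index_t, 1> batch_stride_ds = {BatchStrideD0};

    auto arg_ptr = gemm_ptr->MakeArgumentPointer(p_a,
                                                 p_b,
                                                 p_ds,
                                                 p_c,
                                                 M, N, K,
                                                 StrideA, StrideB,
                                                 stride_ds,
                                                 StrideE,
                                                 BatchStrideA, BatchStrideB,
                                                 batch_stride_ds,
                                                 BatchStrideE,
                                                 Batch,
                                                 a_element_op, b_element_op, cde_element_op);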
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp  (view file @ de6a70f7)

This diff is collapsed.