Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
426abafe
Commit
426abafe
authored
May 31, 2022
by
Jing Zhang
Browse files
init desc
parent
88578483
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1004 additions
and
67 deletions
+1004
-67
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+1
-0
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
+240
-0
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+12
-13
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-29
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+76
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
...peration/gpu/device/device_grouped_gemm_transpose_xdl.hpp
+650
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+14
-14
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+11
-11
No files found.
example/15_grouped_gemm/CMakeLists.txt
View file @
426abafe
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_transpose_xdl_fp16 grouped_gemm_transpose_xdl_fp16.cpp
)
example/15_grouped_gemm/grouped_gemm_transpose_xdl_fp16.cpp
0 → 100644
View file @
426abafe
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_transpose_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmTransposeXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
exit
(
0
);
}
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmTransposeDesc
>
gemm_descs
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
gemm_descs
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
1024
;
int
N
=
1024
;
int
K
=
1024
;
gemm_descs
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_device_tensors
;
a_tensors
.
reserve
(
group_count
);
b_tensors
.
reserve
(
group_count
);
c_host_tensors
.
reserve
(
group_count
);
c_device_tensors
.
reserve
(
group_count
);
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
DeviceMemPtr
>
a_tensors_device
,
b_tensors_device
,
c_tensors_device
;
a_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
StrideA
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
K
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideB
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideC
,
CLayout
{})));
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_descs
[
i
].
M
,
gemm_descs
[
i
].
N
,
gemm_descs
[
i
].
StrideC
,
CLayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_descs
[
i
].
M
*
gemm_descs
[
i
].
K
*
gemm_descs
[
i
].
N
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
case
2
:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
b_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSpace
()));
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSpace
()));
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
].
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
].
mData
.
data
());
p_a
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_b
.
push_back
(
b_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_c
.
push_back
(
c_tensors_device
[
i
]
->
GetDeviceBuffer
());
}
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_tensors
[
i
],
b_tensors
[
i
],
c_host_tensors
[
i
],
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
].
mData
,
c_host_tensors
[
i
].
mData
);
}
}
return
pass
?
0
:
1
;
}
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
426abafe
...
@@ -81,11 +81,11 @@ int main(int argc, char* argv[])
...
@@ -81,11 +81,11 @@ int main(int argc, char* argv[])
int
group_count
=
rand
()
%
16
+
1
;
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
void
*>
p_c
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
{
...
@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
...
@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
int
N
=
128
+
128
*
i
;
int
N
=
128
+
128
*
i
;
int
K
=
64
+
64
*
i
;
int
K
=
64
+
64
*
i
;
gemm_
shape
s
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
gemm_
desc
s
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
});
}
}
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
@@ -111,7 +111,6 @@ int main(int argc, char* argv[])
...
@@ -111,7 +111,6 @@ int main(int argc, char* argv[])
};
};
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
BDataType
>>
b_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_device_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_device_tensors
;
...
@@ -131,22 +130,22 @@ int main(int argc, char* argv[])
...
@@ -131,22 +130,22 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
StrideA
,
ALayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
K
,
gemm_
desc
s
[
i
].
StrideA
,
ALayout
{})));
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
b_tensors
.
push_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideB
,
BLayout
{})));
gemm_
desc
s
[
i
].
K
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideB
,
BLayout
{})));
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
c_host_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideC
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideC
,
CLayout
{})));
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
c_device_tensors
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideC
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideC
,
CLayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
shape
s
[
i
].
M
*
gemm_
shape
s
[
i
].
K
*
gemm_
shape
s
[
i
].
N
;
flop
+=
std
::
size_t
(
2
)
*
gemm_
desc
s
[
i
].
M
*
gemm_
desc
s
[
i
].
K
*
gemm_
desc
s
[
i
].
N
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
sizeof
(
CDataType
)
*
c_device_tensors
[
i
].
mDesc
.
GetElementSize
();
...
@@ -168,7 +167,7 @@ int main(int argc, char* argv[])
...
@@ -168,7 +167,7 @@ int main(int argc, char* argv[])
}
}
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors_device
.
emplace_back
(
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
...
@@ -194,7 +193,7 @@ int main(int argc, char* argv[])
...
@@ -194,7 +193,7 @@ int main(int argc, char* argv[])
// do GEMM
// do GEMM
auto
argument
=
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
a_element_op
,
b_element_op
,
c_element_op
);
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_
desc
s
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -219,7 +218,7 @@ int main(int argc, char* argv[])
...
@@ -219,7 +218,7 @@ int main(int argc, char* argv[])
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
...
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
426abafe
...
@@ -8,12 +8,6 @@ namespace ck {
...
@@ -8,12 +8,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
struct
GemmShape
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
...
@@ -42,29 +36,6 @@ template <typename AElementwiseOperation,
...
@@ -42,29 +36,6 @@ template <typename AElementwiseOperation,
using
DeviceGemmPtr
=
std
::
unique_ptr
<
using
DeviceGemmPtr
=
std
::
unique_ptr
<
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
0 → 100644
View file @
426abafe
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
GemmDesc
{
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
};
struct
GemmTransposeDesc
{
ck
::
index_t
M
,
N
,
K
;;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
ck
::
index_t
B
,
S
,
NumHead
,
HeadDim
;
std
::
vector
<
ck
::
index_t
>
transpose
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmPtr
=
std
::
unique_ptr
<
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmTranspose
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmTransposeDesc
>&
gemm_transpose_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGroupedGemmTransposePtr
=
std
::
unique_ptr
<
DeviceGroupedGemmTranspose
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_transpose_xdl.hpp
0 → 100644
View file @
426abafe
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
426abafe
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include <sstream>
#include <sstream>
#include "device.hpp"
#include "device.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "device_
grouped_
gemm.hpp"
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -349,7 +349,7 @@ struct DeviceGroupedGemmXdl
...
@@ -349,7 +349,7 @@ struct DeviceGroupedGemmXdl
Argument
(
std
::
vector
<
const
void
*>&
p_a
,
Argument
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
Gemm
Shape
>&
gemm_
shape
s
,
std
::
vector
<
Gemm
Desc
>&
gemm_
desc
s
,
index_t
M01
,
index_t
M01
,
index_t
N01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
@@ -365,7 +365,7 @@ struct DeviceGroupedGemmXdl
...
@@ -365,7 +365,7 @@ struct DeviceGroupedGemmXdl
gemm_descs_args_workspace_
=
nullptr
;
gemm_descs_args_workspace_
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_
shape
s
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_
desc
s
.
size
());
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_a
.
size
())
&&
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_a
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_b
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_b
.
size
())
&&
...
@@ -376,15 +376,15 @@ struct DeviceGroupedGemmXdl
...
@@ -376,15 +376,15 @@ struct DeviceGroupedGemmXdl
gemm_desc_kernel_arg_
.
reserve
(
group_count_
);
gemm_desc_kernel_arg_
.
reserve
(
group_count_
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
const
index_t
M
=
gemm_
shape
s
[
i
].
M
;
const
index_t
M
=
gemm_
desc
s
[
i
].
M
;
const
index_t
N
=
gemm_
shape
s
[
i
].
N
;
const
index_t
N
=
gemm_
desc
s
[
i
].
N
;
const
index_t
K
=
gemm_
shape
s
[
i
].
K
;
const
index_t
K
=
gemm_
desc
s
[
i
].
K
;
const
index_t
StrideA
=
gemm_
shape
s
[
i
].
StrideA
;
const
index_t
StrideA
=
gemm_
desc
s
[
i
].
StrideA
;
const
index_t
StrideB
=
gemm_
shape
s
[
i
].
StrideB
;
const
index_t
StrideB
=
gemm_
desc
s
[
i
].
StrideB
;
const
index_t
StrideC
=
gemm_
shape
s
[
i
].
StrideC
;
const
index_t
StrideC
=
gemm_
desc
s
[
i
].
StrideC
;
const
auto
a_grid_desc_k0_m_k1_
=
const
auto
a_grid_desc_k0_m_k1_
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
...
@@ -580,12 +580,12 @@ struct DeviceGroupedGemmXdl
...
@@ -580,12 +580,12 @@ struct DeviceGroupedGemmXdl
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_a
,
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
Gemm
Shape
>
gemm_
shape
s
,
std
::
vector
<
Gemm
Desc
>
gemm_
desc
s
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_
desc
s
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -594,14 +594,14 @@ struct DeviceGroupedGemmXdl
...
@@ -594,14 +594,14 @@ struct DeviceGroupedGemmXdl
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
Gemm
Shape
>&
gemm_
shape
s
,
std
::
vector
<
Gemm
Desc
>&
gemm_
desc
s
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
p_a
,
p_b
,
p_c
,
gemm_
desc
s
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
}
}
// polymorphic
// polymorphic
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
426abafe
...
@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -52,11 +52,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int
group_count
=
rand
()
%
10
+
1
;
int
group_count
=
rand
()
%
10
+
1
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Shape
>
gemm_
shape
s
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
Gemm
Desc
>
gemm_
desc
s
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
void
*>
p_c
;
gemm_
shape
s
.
reserve
(
group_count
);
gemm_
desc
s
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
{
...
@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -68,7 +68,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
int
BStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
?
N
:
K
;
int
BStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
?
N
:
K
;
int
CStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
?
N
:
M
;
int
CStride
=
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
?
N
:
M
;
gemm_
shape
s
.
push_back
({
M
,
N
,
K
,
AStride
,
BStride
,
CStride
});
gemm_
desc
s
.
push_back
({
M
,
N
,
K
,
AStride
,
BStride
,
CStride
});
}
}
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -104,22 +104,22 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
b_tensors_device
.
reserve
(
group_count
);
b_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
c_tensors_device
.
reserve
(
group_count
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors
.
emplace_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
a_tensors
.
emplace_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
StrideA
,
ALayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
K
,
gemm_
desc
s
[
i
].
StrideA
,
ALayout
{})));
b_tensors
.
emplace_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
b_tensors
.
emplace_back
(
Tensor
<
BDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
K
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideB
,
BLayout
{})));
gemm_
desc
s
[
i
].
K
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideB
,
BLayout
{})));
c_host_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
c_host_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideC
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideC
,
CLayout
{})));
c_device_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
c_device_tensors
.
emplace_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
gemm_
shape
s
[
i
].
M
,
gemm_
shape
s
[
i
].
N
,
gemm_
shape
s
[
i
].
StrideC
,
CLayout
{})));
gemm_
desc
s
[
i
].
M
,
gemm_
desc
s
[
i
].
N
,
gemm_
desc
s
[
i
].
StrideC
,
CLayout
{})));
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
}
}
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
a_tensors_device
.
emplace_back
(
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()));
...
@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -144,7 +144,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto
invoker_ptr
=
groupedGemmPtr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
groupedGemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
groupedGemmPtr
->
MakeArgumentPointer
(
auto
argument_ptr
=
groupedGemmPtr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_c
,
gemm_
shape
s
,
a_element_op
,
b_element_op
,
c_element_op
);
p_a
,
p_b
,
p_c
,
gemm_
desc
s
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
groupedGemmPtr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
DeviceMem
gemm_desc_workspace
(
groupedGemmPtr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
...
@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
...
@@ -152,7 +152,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
invoker_ptr
->
Run
(
argument_ptr
.
get
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
shape
s
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_
desc
s
.
size
();
i
++
)
{
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment