composable_kernel · Commit e3a4b967
Authored Mar 12, 2022 by Jing Zhang

fixed mem issue with unique_ptr

Parent: 8fb2b172
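For context: the message is terse, and a plausible reading of the "mem issue", judging from the diffs below, is one of ownership. The profiler previously kept DeviceMem objects by value inside a std::vector, so push_back and vector reallocation copied objects that own a raw hipMalloc'd pointer, and more than one destructor could then release the same buffer. After this commit the example owns each device buffer through a std::unique_ptr<DeviceMem>, and the profiler falls back to raw void* allocations. A minimal sketch of the two ownership patterns, illustrative only and assuming ~DeviceMem() releases its buffer with hipFree:

    // Illustrative sketch, not part of the commit. Assumes DeviceMem (declared in
    // library/include/ck/library/host_tensor/device.hpp; include path assumed here)
    // frees its hipMalloc'd buffer in its destructor.
    #include <memory>
    #include <vector>
    #include "host_tensor/device.hpp"

    void ownership_sketch()
    {
        // Problematic pattern: DeviceMem held by value. push_back and reallocation
        // copy the object; every copy aliases the same device pointer, and each
        // destructor then tries to release the same allocation.
        std::vector<DeviceMem> by_value;
        by_value.push_back(DeviceMem(1024)); // temporary is copied in, then destroyed

        // Pattern adopted by the example in this commit: one heap-allocated DeviceMem
        // per tensor, owned by a unique_ptr, so it is never copied or double-freed.
        using DeviceMemPtr = std::unique_ptr<DeviceMem>;
        std::vector<DeviceMemPtr> by_pointer;
        by_pointer.push_back(std::make_unique<DeviceMem>(1024));
        void* p_dev = by_pointer[0]->GetDeviceBuffer(); // stable pointer for kernel arguments
        (void)p_dev;
    }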
Showing 10 changed files with 268 additions and 910 deletions (+268 −910).
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp  (+0 −698)
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp  (+28 −57)
include/ck/tensor_operation/gpu/device/device_gemm.hpp  (+1 −1)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp  (+14 −9)
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp  (+6 −8)
library/include/ck/library/host_tensor/device.hpp  (+1 −0)
library/src/host_tensor/device.cpp  (+6 −0)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp  (+6 −6)
profiler/include/profile_grouped_gemm_impl.hpp  (+184 −113)
profiler/src/profile_grouped_gemm.cpp  (+22 −18)
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
deleted (100644 → 0); the 698 removed lines are collapsed and not shown.
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp

@@ -81,19 +81,17 @@ int main(int argc, char* argv[])
     // GEMM shape
     std::vector<ck::GemmShape> gemm_shapes;
-    int A_size = 0, B_size = 0, C_size = 0;

     for(int i = 0; i < group_count; i++)
     {
-        int M = 256 + 256 * i;
-        int N = 128 + 128 * i;
-        int K = 64 + 64 * i;
+        // int M = 256 + 256 * i;
+        // int N = 128 + 128 * i;
+        // int K = 64 + 64 * i;

-        gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
+        int M = 3840;
+        int N = 1024;
+        int K = 4096;

-        A_size += gemm_shapes[i].M * gemm_shapes[i].K;
-        B_size += gemm_shapes[i].N * gemm_shapes[i].K;
-        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
+        gemm_shapes.push_back({M, N, K, K, N, N, nullptr, nullptr, nullptr});
     }

     auto f_host_tensor_descriptor =

@@ -115,6 +113,10 @@ int main(int argc, char* argv[])
     std::vector<Tensor<CDataType>> c_host_tensors;
     std::vector<Tensor<CDataType>> c_device_tensors;

+    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
+
+    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;

     std::size_t flop = 0, num_btype = 0;

     for(int i = 0; i < gemm_shapes.size(); i++)

@@ -133,13 +135,10 @@ int main(int argc, char* argv[])
                   << std::endl;

         flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
-        num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
-                     sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
-                     sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
-    }
+        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
+                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
+                     sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();

-    for(int i = 0; i < gemm_shapes.size(); i++)
-    {
         switch(init_method)
         {
         case 0: break;

@@ -157,38 +156,23 @@ int main(int argc, char* argv[])
         }
     }

-    DeviceMem a_tensors_device_buf(sizeof(ADataType) * A_size);
-    DeviceMem b_tensors_device_buf(sizeof(BDataType) * B_size);
-    DeviceMem c_tensors_device_buf(sizeof(CDataType) * C_size);
-
-    std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;
-
-    A_size = 0;
-    B_size = 0;
-    C_size = 0;
-
     for(int i = 0; i < gemm_shapes.size(); i++)
     {
-        a_tensors_data.insert(a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
-        b_tensors_data.insert(b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
-
-        gemm_shapes[i].p_a = static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
-        gemm_shapes[i].p_b = static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
-        gemm_shapes[i].p_c = static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
-
-        A_size += gemm_shapes[i].M * gemm_shapes[i].K;
-        B_size += gemm_shapes[i].N * gemm_shapes[i].K;
-        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
+        a_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
+        b_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(BDataType) * a_tensors[i].mDesc.GetElementSize()));
+        c_tensors_device.push_back(
+            std::make_unique<DeviceMem>(sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
+
+        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
+        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
+
+        gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
+        gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
+        gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
     }

-    a_tensors_device_buf.ToDevice(a_tensors_data.data());
-    b_tensors_device_buf.ToDevice(b_tensors_data.data());
-
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};

@@ -214,24 +198,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << gemm.GetTypeString() << std::endl;

-    c_tensors_data.resize(C_size);
-    c_tensors_device_buf.FromDevice(c_tensors_data.data());
-
-    C_size = 0;
-    for(int i = 0; i < gemm_shapes.size(); i++)
-    {
-        memcpy(c_device_tensors[i].mData.data(),
-               c_tensors_data.data() + C_size,
-               c_device_tensors[i].mData.size() * sizeof(CDataType));
-
-        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
-    }
-
     if(do_verification)
     {
         for(int i = 0; i < gemm_shapes.size(); i++)
         {
+            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
+
             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();
include/ck/tensor_operation/gpu/device/device_gemm.hpp

@@ -70,7 +70,7 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

@@ -242,7 +242,7 @@ struct DeviceGroupedGemmXdl
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(std::vector<GemmShape> gemm_shapes,
+        Argument(std::vector<GemmShape>& gemm_shapes,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,

@@ -360,8 +360,7 @@ struct DeviceGroupedGemmXdl
                 if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
                 {
-                    throw std::runtime_error(
-                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                    throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
                 }
             }
         });

@@ -435,11 +434,17 @@ struct DeviceGroupedGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(arg.GemmShape_[0].a_grid_desc_k0_m_k1_,
-                                           arg.GemmShape_[0].b_grid_desc_k0_n_k1_,
-                                           arg.GemmShape_[0].c_grid_desc_m_n_,
-                                           arg.M01_,
-                                           arg.N01_);
+        bool isValid = true;
+
+        for(int i = 0; i < arg.GemmShape_.size(); i++)
+        {
+            isValid &= GridwiseGemm::CheckValidity(arg.GemmShape_[i].a_grid_desc_k0_m_k1_,
+                                                   arg.GemmShape_[i].b_grid_desc_k0_n_k1_,
+                                                   arg.GemmShape_[i].c_grid_desc_m_n_,
+                                                   arg.M01_,
+                                                   arg.N01_);
+        }
+
+        return isValid;
     }

     // polymorphic

@@ -459,7 +464,7 @@ struct DeviceGroupedGemmXdl
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp

@@ -60,24 +60,22 @@ __global__ void
         }
     });
 #else
-    const GemmDesc* gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);

     index_t group_id = 0;

     static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
-                       ? i
-                       : group_id;
+        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+                    i < group_count)
+                       ? i
+                       : group_id;
     });

     const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;

-    const index_t a_offset_grp = gemm_desc_ptr[group_id].OffsetA;
-    const index_t b_offset_grp = gemm_desc_ptr[group_id].OffsetB;
-    const index_t c_offset_grp = gemm_desc_ptr[group_id].OffsetC;
-
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_offset_grp,
-                                                   p_b_grid + b_offset_grp,
-                                                   p_c_grid + c_offset_grp,
+    GridwiseGemm::template Run<HasMainK0BlockLoop>(gemm_desc_ptr[group_id].a_ptr,
+                                                   gemm_desc_ptr[group_id].b_ptr,
+                                                   gemm_desc_ptr[group_id].c_ptr,
                                                    p_shared,
                                                    gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
                                                    gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
                                                    ...
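To clarify the kernel change above: each group occupies a contiguous range of workgroups [BlockStart, BlockEnd), and the static_for scan picks the group whose range contains the current block_id (the added `i < group_count` term keeps the compile-time loop over MaxGroupCount from matching unused descriptor slots). A plain run-time loop equivalent, with a hypothetical simplified descriptor type standing in for GemmDesc:

    // Illustrative equivalent of the static_for group search; GroupRange is a
    // hypothetical stand-in for the BlockStart/BlockEnd fields of GemmDesc.
    struct GroupRange
    {
        int BlockStart; // first block id belonging to this group
        int BlockEnd;   // one past the last block id of this group
    };

    // Returns the index of the group whose [BlockStart, BlockEnd) contains block_id.
    inline int find_group_id(const GroupRange* desc, int group_count, int block_id)
    {
        int group_id = 0;
        for(int i = 0; i < group_count; ++i)
        {
            if(block_id >= desc[i].BlockStart && block_id < desc[i].BlockEnd)
            {
                group_id = i;
            }
        }
        return group_id;
    }

The kernel then rebases the block index into the group (block_id_grp = block_id - BlockStart) and, after this commit, reads the group's A/B/C pointers directly from the descriptor instead of applying OffsetA/OffsetB/OffsetC to shared base pointers.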
library/include/ck/library/host_tensor/device.hpp

@@ -12,6 +12,7 @@ struct DeviceMem
 {
     DeviceMem() = delete;
     DeviceMem(std::size_t mem_size);
+    DeviceMem(const DeviceMem& p);

     void* GetDeviceBuffer();
     void ToDevice(const void* p);
     void FromDevice(void* p);
library/src/host_tensor/device.cpp

@@ -5,6 +5,12 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 }

+DeviceMem::DeviceMem(const DeviceMem& p) : mpDeviceBuf(p.mpDeviceBuf), mMemSize(p.mMemSize)
+{
+    // hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+    // hipGetErrorString(hipMemcpy(mpDeviceBuf, p.mpDeviceBuf, mMemSize, hipMemcpyDeviceToDevice));
+}
+
 void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

 void DeviceMem::ToDevice(const void* p)
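Worth noting: the copy constructor added here is a shallow, aliasing copy; the deep-copy variant (hipMalloc plus device-to-device hipMemcpy) is left commented out, so a copy and its source refer to the same device allocation. A small illustrative sketch, assuming the destructor releases mpDeviceBuf:

    // Illustrative only; assumes ~DeviceMem() calls hipFree on mpDeviceBuf.
    DeviceMem a(1024);   // allocates a device buffer
    {
        DeviceMem b = a; // shallow copy added by this commit: b aliases a's buffer
    }                    // if the destructor frees here, a's pointer now dangles
    // a.GetDeviceBuffer() past this point would refer to freed device memory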
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp

@@ -23,9 +23,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;

 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
     // clang-format off
     //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
     //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
     //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|

@@ -48,13 +47,14 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
     //DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
     //DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
     DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
     // clang-format on
     >;

 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
     std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
 {
     add_device_operation_instances(instances,
                                    device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
 }

 } // namespace device_grouped_gemm_instance
profiler/include/profile_grouped_gemm_impl.hpp

@@ -16,15 +16,19 @@ namespace tensor_operation {
 namespace device {
 namespace device_grouped_gemm_instance {

 using DeviceGroupedGemmNoOpPtr =
     ck::tensor_operation::device::DeviceGroupedGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                        ck::tensor_operation::element_wise::PassThrough,
                                                        ck::tensor_operation::element_wise::PassThrough>;

 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-//void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-//void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-//void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+// void
+// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
 } // namespace device_grouped_gemm_instance
 } // namespace device
@@ -41,15 +45,15 @@ template <typename ADataType,
           typename BLayout,
           typename CLayout>
 void profile_grouped_gemm_impl(int do_verification,
                                int init_method,
                                bool do_log,
                                int nrepeat,
                                std::vector<int> Ms,
                                std::vector<int> Ns,
                                std::vector<int> Ks,
                                std::vector<int> StrideAs,
                                std::vector<int> StrideBs,
                                std::vector<int> StrideCs)
 {
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {

@@ -65,41 +69,48 @@ void profile_grouped_gemm_impl(int do_verification,
         }
     };

     std::vector<Tensor<ADataType>> a_m_k;
     std::vector<Tensor<BDataType>> b_k_n;
-    std::vector<Tensor<CDataType>> c_m_n;
+    std::vector<Tensor<CDataType>> c_m_n_device_results;
+
+    // int A_size = 0, B_size = 0, C_size = 0;

     for(int i = 0; i < Ms.size(); i++)
     {
         a_m_k.push_back(Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
         b_k_n.push_back(Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
-        c_m_n.push_back(Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+        c_m_n_device_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));

         std::cout << "a_m_k[" << i << "]:" << a_m_k[i].mDesc << std::endl;
         std::cout << "b_k_n[" << i << "]:" << b_k_n[i].mDesc << std::endl;
-        std::cout << "c_m_n[" << i << "]:" << c_m_n[i].mDesc << std::endl;
+        std::cout << "c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc
+                  << std::endl;

         std::size_t num_thread = std::thread::hardware_concurrency();

         switch(init_method)
         {
         case 0: break;
         case 1:
             a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
             b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
             break;
         default:
             a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
             b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
         }

         // set zero to c_device_buf
-        c_m_n[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
+        c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
+
+        // A_size += a_m_k[i].mDesc.GetElementSpace();
+        // B_size += b_k_n[i].mDesc.GetElementSpace();
+        // C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
     }

     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -114,28 +125,112 @@ void profile_grouped_gemm_impl(int do_verification,
     // }

-    std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
-    //DeviceMem a_device_buf(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
-    //DeviceMem b_device_buf(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
-    //DeviceMem c_device_buf(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace());
+    // std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
+    std::vector<void*> a_device_buf, b_device_buf, c_device_buf;
+
+    // DeviceMem a_device_buf_(sizeof(ADataType) * A_size);
+    // DeviceMem b_device_buf_(sizeof(BDataType) * B_size);
+    // DeviceMem c_device_buf_(sizeof(CDataType) * C_size);
+
+    // std::vector<ADataType> a_tensors_data;
+    // std::vector<BDataType> b_tensors_data;
+    // std::vector<CDataType> c_tensors_data;
+
+    std::vector<GemmShape> gemm_shapes;
+
+    // A_size = 0;
+    // B_size = 0;
+    // C_size = 0;

     for(int i = 0; i < Ms.size(); i++)
     {
-        a_device_buf.push_back(DeviceMem(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
-        b_device_buf.push_back(DeviceMem(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
-        c_device_buf.push_back(DeviceMem(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace()));
-
-        a_device_buf[i].ToDevice(a_m_k[i].mData.data());
-        b_device_buf[i].ToDevice(b_k_n[i].mData.data());
-        c_device_buf[i].ToDevice(c_m_n[i].mData.data());
+        // a_tensors_data.insert(a_tensors_data.end(), a_m_k[i].mData.begin(),
+        // a_m_k[i].mData.end()); b_tensors_data.insert(b_tensors_data.end(),
+        // b_k_n[i].mData.begin(), b_k_n[i].mData.end());
+        // c_tensors_data.insert(c_tensors_data.end(), c_m_n_device_results[i].mData.begin(),
+        // c_m_n_device_results[i].mData.end());
+
+        void *a_device_buf_, *b_device_buf_, *c_device_buf_;
+
+        hipGetErrorString(hipMalloc(static_cast<void**>(&a_device_buf_),
+                                    sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
+        hipGetErrorString(hipMalloc(static_cast<void**>(&b_device_buf_),
+                                    sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
+        hipGetErrorString(hipMalloc(static_cast<void**>(&c_device_buf_),
+                                    sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
+
+        // DeviceMem a_device_buf_(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
+        // DeviceMem b_device_buf_(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
+        // DeviceMem c_device_buf_(sizeof(CDataType) *
+        // c_m_n_device_results[i].mDesc.GetElementSpace());
+
+        hipGetErrorString(hipMemcpy(a_device_buf_,
+                                    a_m_k[i].mData.data(),
+                                    sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpy(b_device_buf_,
+                                    b_k_n[i].mData.data(),
+                                    sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+        hipGetErrorString(hipMemcpy(c_device_buf_,
+                                    c_m_n_device_results[i].mData.data(),
+                                    sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace(),
+                                    hipMemcpyHostToDevice));
+
+        // a_device_buf_.ToDevice(a_m_k[i].mData.data());
+        // b_device_buf_.ToDevice(b_k_n[i].mData.data());
+        // c_device_buf_.ToDevice(c_m_n_device_results[i].mData.data());
+
+        a_device_buf.push_back(a_device_buf_);
+        b_device_buf.push_back(b_device_buf_);
+        c_device_buf.push_back(c_device_buf_);
+
+        // a_device_buf.push_back(a_device_buf_);
+        // b_device_buf.push_back(b_device_buf_);
+        // c_device_buf.push_back(c_device_buf_);
+
+        // gemm_shapes.push_back({Ms[i],
+        //                        Ns[i],
+        //                        Ks[i],
+        //                        StrideAs[i],
+        //                        StrideBs[i],
+        //                        StrideCs[i],
+        //                        a_device_buf[i].GetDeviceBuffer(),
+        //                        b_device_buf[i].GetDeviceBuffer(),
+        //                        c_device_buf[i].GetDeviceBuffer()});
+
+        // printf("%p %p %p\n",
+        //        a_device_buf[i].GetDeviceBuffer(),
+        //        b_device_buf[i].GetDeviceBuffer(),
+        //        c_device_buf[i].GetDeviceBuffer());
+
+        gemm_shapes.push_back({Ms[i],
+                               Ns[i],
+                               Ks[i],
+                               StrideAs[i],
+                               StrideBs[i],
+                               StrideCs[i],
+                               a_device_buf_,
+                               b_device_buf_,
+                               c_device_buf_});
+
+        // A_size += a_m_k[i].mDesc.GetElementSpace();
+        // B_size += b_k_n[i].mDesc.GetElementSpace();
+        // C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
     }

+    // a_device_buf_.ToDevice(a_tensors_data.data());
+    // b_device_buf_.ToDevice(b_tensors_data.data());
+    // c_device_buf_.ToDevice(c_tensors_data.data());
+
     // add device GEMM instances
     std::vector<ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
         gemm_ptrs;

     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
     {
         if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                      is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&

@@ -143,7 +238,6 @@ void profile_grouped_gemm_impl(int do_verification,
         {
             ck::tensor_operation::device::device_grouped_gemm_instance::
                 add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
         }
 #if 0
         else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
@@ -216,24 +310,15 @@ void profile_grouped_gemm_impl(int do_verification,
     float best_tflops     = 0;
     float best_gb_per_sec = 0;

-#if 0
+#if 1
     // profile device GEMM instances
     for(auto& gemm_ptr : gemm_ptrs)
     {
         auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                          M,
-                                          N,
-                                          K,
-                                          StrideA,
-                                          StrideB,
-                                          StrideC,
+            gemm_ptr->MakeArgumentPointer(gemm_shapes,
                                           ck::tensor_operation::element_wise::PassThrough{},
                                           ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          KBatch);
+                                          ck::tensor_operation::element_wise::PassThrough{});

         auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

@@ -243,6 +328,7 @@ void profile_grouped_gemm_impl(int do_verification,
         float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);

+#if 0
         std::size_t flop = std::size_t(2) * M * N * K;
         std::size_t num_btype =

@@ -262,54 +348,36 @@ void profile_grouped_gemm_impl(int do_verification,
             best_ave_time   = ave_time;
             best_gb_per_sec = gb_per_sec;
         }
+#endif

         if(do_verification)
         {
-            c_device_buf.FromDevice(c_m_n_device_result.mData.data());
-
-            if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
-                         is_same<BDataType, ck::bhalf_t>::value &&
-                         is_same<CDataType, ck::bhalf_t>::value)
-            {
-                Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-                Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-                Tensor<float> c_m_n_host_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-                Tensor<float> c_m_n_device_f32_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-                bf16_to_f32_(a_m_k, a_f32_m_k);
-                bf16_to_f32_(b_k_n, b_f32_k_n);
-                bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
-
-                using ReferenceGemmInstance = ck::tensor_operation::host::
-                    ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
-
-                auto ref_gemm    = ReferenceGemmInstance{};
-                auto ref_invoker = ref_gemm.MakeInvoker();
-
-                auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
-                                                          b_f32_k_n,
-                                                          c_m_n_host_result,
-                                                          a_element_op,
-                                                          b_element_op,
-                                                          c_element_op);
-
-                ref_invoker.Run(ref_argument);
-
-                check_error(c_m_n_host_result, c_m_n_device_f32_result);
-
-                if(do_log)
-                {
-                    LogRangeAsType<float>(
-                        std::cout << "c_host : ", c_m_n_host_result.mData, ",")
-                        << std::endl;
-                }
-            }
-            else
-            {
-                Tensor<CDataType> c_m_n_host_result(
-                    f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+            // c_tensors_data.resize(C_size);
+            // c_device_buf_.FromDevice(c_tensors_data.data());
+
+            // C_size = 0;
+            // for(int i = 0; i < gemm_shapes.size(); i++)
+            //{
+            // memcpy(c_m_n_device_results[i].mData.data(),
+            //        c_tensors_data.data() + C_size,
+            //        c_m_n_device_results[i].mDesc.GetElementSpace() * sizeof(CDataType));
+            // C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
+            //}
+
+            for(int i = 0; i < gemm_shapes.size(); i++)
+            {
+                hipGetErrorString(hipMemcpy(c_m_n_device_results[i].mData.data(),
+                                            c_device_buf[i],
+                                            sizeof(CDataType) *
+                                                c_m_n_device_results[i].mDesc.GetElementSpace(),
+                                            hipMemcpyDeviceToHost));
+
+                // hipGetErrorString(hipFree(c_device_buf[i]));
+
+                Tensor<CDataType> c_m_n_host_result(
+                    f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));

                 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,

@@ -322,27 +390,30 @@ void profile_grouped_gemm_impl(int do_verification,
                 auto ref_gemm    = ReferenceGemmInstance{};
                 auto ref_invoker = ref_gemm.MakeInvoker();

-                auto ref_argument = ref_gemm.MakeArgument(
-                    a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+                auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
+                                                          b_k_n[i],
+                                                          c_m_n_host_result,
+                                                          a_element_op,
+                                                          b_element_op,
+                                                          c_element_op);

                 ref_invoker.Run(ref_argument);

-                check_error(c_m_n_host_result, c_m_n_device_result);
+                check_error(c_m_n_host_result, c_m_n_device_results[i]);

                 if(do_log)
                 {
-                    LogRangeAsType<float>(
-                        std::cout << "c_host : ", c_m_n_host_result.mData, ",")
-                        << std::endl;
+                    // LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
+                    //<< std::endl;
+                    // LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") <<
+                    // std::endl;
+                    LogRangeAsType<float>(
+                        std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
+                        << std::endl;
+                    // LogRangeAsType<float>(
+                    //    std::cout << "c_host : ", c_m_n_host_result.mData, ",")
+                    //<< std::endl;
                 }
             }
-
-            if(do_log)
-            {
-                LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-                LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
-                LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
-                    << std::endl;
-            }
         }
     }
     else
profiler/src/profile_grouped_gemm.cpp

@@ -26,7 +26,7 @@ enum GemmDataType
     INT8_INT8_INT8, // 3
 };

 std::vector<int> stringToArray(char* input)
 {
     std::vector<int> out;

@@ -34,7 +34,8 @@ std::vector<int> stringToArray(char *input)
     std::string item;
     while(std::getline(in, item, ','))
     {
         out.push_back(std::stoi(item));
     }
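stringToArray splits one comma-separated command-line argument into a vector of integers; the grouped profiler applies it to argv[8] through argv[13] to read the per-group Ms, Ns, Ks and strides, as the next hunk shows. A hypothetical call for a three-group problem (mutable buffers stand in for argv entries, since the parameter is a non-const char*):

    // Illustrative usage only; values are made up.
    char ms_arg[] = "256,512,768";
    char ns_arg[] = "128,256,384";
    char ks_arg[] = "64,128,192";

    const auto Ms = stringToArray(ms_arg); // {256, 512, 768}
    const auto Ns = stringToArray(ns_arg); // {128, 256, 384}
    const auto Ks = stringToArray(ks_arg); // {64, 128, 192}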
@@ -69,30 +70,33 @@ int profile_grouped_gemm(int argc, char* argv[])
     const auto Ms       = stringToArray(argv[8]);
     const auto Ns       = stringToArray(argv[9]);
     const auto Ks       = stringToArray(argv[10]);
     const auto StrideAs = stringToArray(argv[11]);
     const auto StrideBs = stringToArray(argv[12]);
     const auto StrideCs = stringToArray(argv[13]);

     for(int i = 0; i < Ms.size(); i++)
     {
         std::cout << "M: " << Ms[i] << " N: " << Ns[i] << " K: " << Ks[i] << std::endl;
     }

     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_grouped_gemm_impl<ck::half_t,
                                                 ck::half_t,
                                                 ck::half_t,
                                                 ck::tensor_layout::gemm::RowMajor,
                                                 ck::tensor_layout::gemm::RowMajor,
                                                 ck::tensor_layout::gemm::RowMajor>(do_verification,
                                                                                    init_method,
                                                                                    do_log,
                                                                                    nrepeat,
                                                                                    Ms,
                                                                                    Ns,
                                                                                    Ks,
                                                                                    StrideAs,
                                                                                    StrideBs,
                                                                                    StrideCs);
     }
 #if 0
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)