gaoqiong / composable_kernel · Commit e86a2748

Commit e86a2748, authored Mar 21, 2022 by Jing Zhang

    fixed comments: unified blk2ctile; add test

Parent: 8df7bd01

Showing 6 changed files with 286 additions and 25 deletions (+286 -25):
- include/ck/tensor_operation/gpu/device/device_gemm.hpp (+3 -3)
- include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp (+43 -14)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp (+5 -8)
- test/CMakeLists.txt (+1 -0)
- test/grouped_gemm/CMakeLists.txt (+3 -0, new file)
- test/grouped_gemm/grouped_gemm_fp16.cpp (+231 -0, new file)
include/ck/tensor_operation/gpu/device/device_gemm.hpp

```diff
@@ -76,9 +76,9 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
-                                                              std::vector<const void*> p_b,
-                                                              std::vector<void*> p_c,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                              std::vector<const void*>& p_b,
+                                                              std::vector<void*>& p_c,
                                                               std::vector<GemmShape>& gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
```
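The signature change swaps pass-by-value vectors for non-const lvalue references, so each call no longer copies the pointer and shape vectors. It also means callers must pass named lvalues, since a non-const reference will not bind to a temporary. A minimal standalone sketch of that C++ rule (not CK code):

```cpp
#include <vector>

// Stand-in for the new by-reference parameter style.
static void take_by_ref(std::vector<const void*>& v) { (void)v; }

int main()
{
    std::vector<const void*> p_a; // named lvalue: binds to the reference
    take_by_ref(p_a);
    // take_by_ref({});           // would not compile: a temporary cannot
    //                            // bind to a non-const lvalue reference
}
```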
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

```diff
@@ -223,6 +223,35 @@ struct DeviceGroupedGemmXdl
                                                 CThreadTransferDstScalarPerVector,
                                                 NumPrefetch>;
 
+    struct GroupedGemmBlock2CTileMap
+    {
+        GroupedGemmBlock2CTileMap()
+        {
+            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
+            BlockStart_        = -1;
+        }
+
+        GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
+                                  index_t M01,
+                                  index_t N01,
+                                  ck::index_t BlockStart)
+        {
+            block_2_ctile_map_ =
+                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);
+            BlockStart_ = BlockStart;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            return block_2_ctile_map_.CalculateBottomIndex(
+                make_multi_index(idx_top[I0] - BlockStart_));
+        }
+
+        private:
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        ck::index_t BlockStart_;
+    };
+
     struct GemmDescKernelArg
     {
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
```
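`GroupedGemmBlock2CTileMap` wraps the default block-to-C-tile map and records the first global workgroup id owned by its group (`BlockStart_`); `CalculateBottomIndex` subtracts that offset, so each group's tile map sees block ids that start at 0. A standalone sketch of the offset idea (not CK code):

```cpp
#include <cassert>

// Each group owns a contiguous range of global workgroup ids; subtracting
// the range start yields a group-local block id for the per-group tile map.
struct OffsetTileMap
{
    int block_start; // first global workgroup id owned by this group

    int ToLocal(int global_block_id) const { return global_block_id - block_start; }
};

int main()
{
    OffsetTileMap group1{16};        // suppose group 1 starts at global block 16
    assert(group1.ToLocal(16) == 0); // its first block maps to tile 0
    assert(group1.ToLocal(21) == 5); // its sixth block maps to tile 5
}
```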
```diff
@@ -232,21 +261,21 @@ struct DeviceGroupedGemmXdl
         typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;
 
         const ADataType* a_ptr;
         const BDataType* b_ptr;
         CDataType* c_ptr;
 
-        ck::index_t BlockStart, BlockEnd;
+        ck::index_t BlockStart_, BlockEnd_;
     };
 
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(std::vector<const void*> p_a,
-                 std::vector<const void*> p_b,
-                 std::vector<void*> p_c,
+        Argument(std::vector<const void*>& p_a,
+                 std::vector<const void*>& p_b,
+                 std::vector<void*>& p_c,
                  std::vector<GemmShape>& gemm_shapes,
                  index_t M01,
                  index_t N01,
```
```diff
@@ -301,15 +330,15 @@ struct DeviceGroupedGemmXdl
                 const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                     GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
 
-                const auto block_2_ctile_map_ =
-                    GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
+                const auto grouped_gemm_block_2_ctile_map_ =
+                    GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
 
                 gemm_desc_kernel_arg_.push_back(
                     GemmDescKernelArg{a_grid_desc_k0_m_k1_,
                                       b_grid_desc_k0_n_k1_,
                                       c_grid_desc_m_n_,
                                       c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                      block_2_ctile_map_,
+                                      grouped_gemm_block_2_ctile_map_,
                                       static_cast<const ADataType*>(p_a[i]),
                                       static_cast<const BDataType*>(p_b[i]),
                                       static_cast<CDataType*>(p_c[i]),
```
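Each `GemmDescKernelArg` now carries the wrapped tile map constructed with that group's `BlockStart`. This hunk does not show how the block ranges themselves are computed, but a natural reading is a running sum of per-group grid sizes; a hypothetical sketch of that bookkeeping (names and values are illustrative only):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical tile counts per group; in CK these would come from each
    // group's C grid descriptor and its block-to-C-tile map.
    std::vector<int> grid_sizes = {8, 12, 6};

    int block_start = 0;
    for(std::size_t g = 0; g < grid_sizes.size(); ++g)
    {
        int block_end = block_start + grid_sizes[g];
        std::printf("group %zu owns blocks [%d, %d)\n", g, block_start, block_end);
        block_start = block_end; // next group starts where this one ends
    }
}
```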
```diff
@@ -470,9 +499,9 @@ struct DeviceGroupedGemmXdl
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }
 
-    static auto MakeArgument(std::vector<const void*> p_a,
-                             std::vector<const void*> p_b,
-                             std::vector<void*> p_c,
+    static auto MakeArgument(std::vector<const void*>& p_a,
+                             std::vector<const void*>& p_b,
+                             std::vector<void*>& p_c,
                              std::vector<GemmShape> gemm_shapes,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
```
```diff
@@ -484,9 +513,9 @@ struct DeviceGroupedGemmXdl
     static auto MakeInvoker() { return Invoker{}; }
 
     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
-                                                      std::vector<const void*> p_b,
-                                                      std::vector<void*> p_c,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                      std::vector<const void*>& p_b,
+                                                      std::vector<void*>& p_c,
                                                       std::vector<GemmShape>& gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

```diff
@@ -80,11 +80,10 @@ __global__ void
 #if 1
     static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
            i < group_count)
         {
-            auto group_id = i;
-            const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
+            auto group_id = i;
 
             GridwiseGemm::template Run<HasMainK0BlockLoop>(
                 gemm_desc_[group_id].a_ptr,
@@ -97,8 +96,7 @@ __global__ void
                 a_element_op,
                 b_element_op,
                 c_element_op,
-                gemm_desc_[group_id].block_2_ctile_map_,
-                block_id_grp);
+                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
         }
     });
 #else
```
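With the per-group offset folded into the tile map, the grouped kernel body simplifies: every block tests its global id against each group's `[BlockStart_, BlockEnd_)` range and runs the matching group's GEMM, with no separate group-local block id to thread through. A standalone sketch of that interval dispatch (not CK code):

```cpp
#include <cstdio>

struct GroupDesc
{
    int block_start, block_end; // global workgroup id range owned by the group
};

int main()
{
    // Hypothetical ranges for three groups covering 26 workgroups.
    const GroupDesc desc[3] = {{0, 8}, {8, 20}, {20, 26}};

    for(int block_id = 0; block_id < 26; ++block_id)
        for(int i = 0; i < 3; ++i)
            if(block_id >= desc[i].block_start && block_id < desc[i].block_end)
                std::printf("block %2d -> group %d\n", block_id, i);
}
```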
```diff
@@ -426,8 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
-        const Block2CTileMap& block_2_ctile_map,
-        ck::index_t block_id = get_block_1d_id())
+        const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -440,7 +437,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // divide block work by [M, N]
         const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
 
         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
```
test/CMakeLists.txt

```diff
@@ -35,6 +35,7 @@ add_subdirectory(space_filling_curve)
 add_subdirectory(conv_util)
 add_subdirectory(reference_conv_fwd)
 add_subdirectory(gemm)
+add_subdirectory(grouped_gemm)
 add_subdirectory(gemm_split_k)
 add_subdirectory(conv2d_fwd)
 add_subdirectory(convnd_fwd)
```
test/grouped_gemm/CMakeLists.txt (new file, mode 100644)

```cmake
add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
target_link_libraries(test_grouped_gemm_fp16 PRIVATE host_tensor)
target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance)
```
test/grouped_gemm/grouped_gemm_fp16.cpp (new file, mode 100644)

```cpp
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "test_util.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGroupedGemmPtr_>&);

}
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace {

using ADataType   = ck::half_t;
using BDataType   = ck::half_t;
using CDataType   = ck::half_t;
using AccDataType = float;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

template <typename T>
static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
{
    float max_diff = 1e-2;

    for(int i = 0; i < ref.mData.size(); ++i)
    {
        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            std::cout << double(ref.mData[i]) << "," << double(result.mData[i]) << std::endl;
            return false;
        }
    }

    return true;
}

bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
{
    int group_count = 4;

    // GEMM shape
    std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

    gemm_shapes.reserve(group_count);

    for(int i = 0; i < group_count; i++)
    {
        int M = 256 + 256 * i;
        int N = 128 + 128 * i;
        int K = 64 + 64 * i;

        int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
        int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
        int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M;

        gemm_shapes.push_back({M, N, K, AStride, BStride, CStride});
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<BDataType>> b_tensors;
    std::vector<Tensor<CDataType>> c_host_tensors;
    std::vector<Tensor<CDataType>> c_device_tensors;

    a_tensors.reserve(group_count);
    b_tensors.reserve(group_count);
    c_host_tensors.reserve(group_count);
    c_device_tensors.reserve(group_count);

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;

    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;

    a_tensors_device.reserve(group_count);
    b_tensors_device.reserve(group_count);
    c_tensors_device.reserve(group_count);

    std::size_t flop = 0, num_btype = 0;

    for(int i = 0; i < gemm_shapes.size(); i++)
    {
        a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
        b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{})));
        c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
        c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));

        // std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
        //           << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
        //           << std::endl;

        flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                     sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();

        a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }

    for(int i = 0; i < gemm_shapes.size(); i++)
    {
        a_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
        b_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));

        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }

    auto a_element_op = PassThrough{};
    auto b_element_op = PassThrough{};
    auto c_element_op = PassThrough{};

    // do GEMM
    auto invoker_ptr  = groupedGemmPtr->MakeInvokerPointer();
    auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
        p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);

    if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    invoker_ptr->Run(argument_ptr.get());

    for(int i = 0; i < gemm_shapes.size(); i++)
    {
        c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_tensors[i], b_tensors[i], c_host_tensors[i], a_element_op, b_element_op,
            c_element_op);

        ref_invoker.Run(ref_argument);

        bool res = check_err(c_device_tensors[i], c_host_tensors[i]);

        std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl;

        if(!res)
            return false;
    }

    return true;
}

} // anonymous namespace

int main()
{
    std::vector<DeviceGroupedGemmPtr_> groupedGemmPtrs;

    ck::tensor_operation::device::device_grouped_gemm_instance::
        add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs);

    bool res = true;

    for(auto& gemmPtr : groupedGemmPtrs)
    {
        res &= TestGroupedGemm(gemmPtr);
    }

    std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
```
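The four test problems grow linearly with the group index, per `M = 256 + 256 * i`, `N = 128 + 128 * i`, `K = 64 + 64 * i`: (M, N, K) = (256, 128, 64), (512, 256, 128), (768, 384, 192) and (1024, 512, 256). Assuming every group validates against the host reference, each instance prints `group_id: 0 SUCCESS` through `group_id: 3 SUCCESS`, and the run ends with `TestGroupedGemm ..... SUCCESS`.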