gaoqiong / composable_kernel

Commit bb9c4a89, authored Mar 16, 2022 by Jing Zhang

fixed comments: reserve, separate ptr and gemm_shapes

Parent: c3952566
Showing 5 changed files with 67 additions and 34 deletions:
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp (+23 -8)
include/ck/config.hpp (+0 -7)
include/ck/tensor_operation/gpu/device/device_gemm.hpp (+10 -1)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp (+34 -17)
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp (+0 -1)
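In short, GemmShape now carries only problem sizes and strides, while the per-group A/B/C device pointers travel alongside it as separate vectors. A minimal caller-side sketch of the new call pattern (illustrative only: group_count, the device-memory vectors, DeviceGemmInstance, and the element ops stand in for what the example actually instantiates; the M value is chosen arbitrarily for the sketch):

    // Sketch, not part of the diff: shapes and device pointers are now separate.
    std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; // sizes + strides only
    std::vector<const void*> p_a, p_b;                                // per-group A/B device pointers
    std::vector<void*> p_c;                                           // per-group C device pointers

    gemm_shapes.reserve(group_count);

    for(int i = 0; i < group_count; i++)
    {
        int M = 256, N = 128 + 128 * i, K = 64 + 64 * i;

        // {M, N, K, StrideA, StrideB, StrideC} -- no nullptr placeholders anymore
        gemm_shapes.push_back({M, N, K, K, K, N});

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }

    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes,
                                      a_element_op, b_element_op, c_element_op);

The diffs below show how the example and the device/gridwise operators were adapted to this pattern.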
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp

...
@@ -79,7 +79,11 @@ int main(int argc, char* argv[])
     int group_count = 4;

     // GEMM shape
-    std::vector<ck::GemmShape> gemm_shapes;
+    std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
+    std::vector<const void*> p_a, p_b;
+    std::vector<void*> p_c;
+
+    gemm_shapes.reserve(group_count);

     for(int i = 0; i < group_count; i++)
     {
...
@@ -87,7 +91,7 @@ int main(int argc, char* argv[])
         int N = 128 + 128 * i;
         int K = 64 + 64 * i;

-        gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
+        gemm_shapes.push_back({M, N, K, K, K, N});
     }

     auto f_host_tensor_descriptor =
...
@@ -105,14 +109,24 @@ int main(int argc, char* argv[])
     };

     std::vector<Tensor<ADataType>> a_tensors;
     std::vector<Tensor<BDataType>> b_tensors;
     std::vector<Tensor<CDataType>> c_host_tensors;
     std::vector<Tensor<CDataType>> c_device_tensors;

+    a_tensors.reserve(group_count);
+    b_tensors.reserve(group_count);
+    c_host_tensors.reserve(group_count);
+    c_device_tensors.reserve(group_count);

     using DeviceMemPtr = std::unique_ptr<DeviceMem>;

     std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;

+    a_tensors_device.reserve(group_count);
+    b_tensors_device.reserve(group_count);
+    c_tensors_device.reserve(group_count);

     std::size_t flop = 0, num_btype = 0;

     for(int i = 0; i < gemm_shapes.size(); i++)
...
@@ -164,9 +178,9 @@ int main(int argc, char* argv[])
         a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
         b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());

-        gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
-        gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
-        gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
+        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
+        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
+        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
     }

     auto a_element_op = AElementOp{};
...
@@ -174,9 +188,10 @@ int main(int argc, char* argv[])
     auto c_element_op = CElementOp{};

     // do GEMM
-    auto gemm     = DeviceGemmInstance{};
-    auto invoker  = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op);
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument =
+        gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
...
include/ck/config.hpp

...
@@ -177,12 +177,5 @@ enum ActivTypeEnum_t
 using index_t      = int32_t;
 using long_index_t = int64_t;

-struct GemmShape
-{
-    ck::index_t M, N, K;
-    ck::index_t StrideA, StrideB, StrideC;
-
-    void *p_a, *p_b, *p_c;
-};

 } // namespace ck
 #endif
include/ck/tensor_operation/gpu/device/device_gemm.hpp

...
@@ -8,6 +8,12 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+struct GemmShape
+{
+    ck::index_t M, N, K;
+    ck::index_t StrideA, StrideB, StrideC;
+};
+
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
...
@@ -70,7 +76,10 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
+                                                              std::vector<const void*> p_b,
+                                                              std::vector<void*> p_c,
+                                                              std::vector<GemmShape>& gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
...
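GemmShape is now a pointer-free aggregate of {M, N, K, StrideA, StrideB, StrideC} living in namespace ck::tensor_operation::device, so the shape list and the three pointer vectors must stay index-aligned. A hedged sketch of that invariant (check_grouped_args is a hypothetical helper, not part of the library; DeviceGroupedGemmXdl::Argument enforces the same condition with a std::runtime_error, as shown in the next file):

    #include <cassert>
    #include <vector>

    namespace dev = ck::tensor_operation::device;

    void check_grouped_args(const std::vector<dev::GemmShape>& gemm_shapes,
                            const std::vector<const void*>& p_a,
                            const std::vector<const void*>& p_b,
                            const std::vector<void*>& p_c)
    {
        // Index i of p_a/p_b/p_c must describe the same group as gemm_shapes[i].
        assert(gemm_shapes.size() == p_a.size() && p_a.size() == p_b.size() &&
               p_b.size() == p_c.size());
    }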
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

...
@@ -244,7 +244,10 @@ struct DeviceGroupedGemmXdl
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(std::vector<GemmShape>& gemm_shapes,
+        Argument(std::vector<const void*> p_a,
+                 std::vector<const void*> p_b,
+                 std::vector<void*> p_c,
+                 std::vector<GemmShape>& gemm_shapes,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,
...
@@ -256,12 +259,20 @@ struct DeviceGroupedGemmXdl
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
         {
-            grid_size = 0;
-
-            group_count_ = 0;
+            grid_size_ = 0;
+
+            group_count_ = static_cast<int>(gemm_shapes.size());
+
+            if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
+                 group_count_ == p_c.size()))
+            {
+                throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
+            }
+
+            gemm_desc_kernel_arg_.reserve(group_count_);

             for(index_t i = 0; i < gemm_shapes.size(); i++)
             {
-                group_count_++;
-
                 const index_t M = gemm_shapes[i].M;
                 const index_t N = gemm_shapes[i].N;
                 const index_t K = gemm_shapes[i].K;
...
@@ -279,10 +290,10 @@ struct DeviceGroupedGemmXdl
                 const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);

-                const index_t BlockStart = grid_size;
-                const index_t BlockEnd   = grid_size + grid_size_grp;
+                const index_t BlockStart = grid_size_;
+                const index_t BlockEnd   = grid_size_ + grid_size_grp;

-                grid_size += grid_size_grp;
+                grid_size_ += grid_size_grp;

                 if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
                                                b_grid_desc_k0_n_k1_,
                                                c_grid_desc_m_n_,
                                                M01_,
                                                N01_))
...
@@ -299,9 +310,9 @@ struct DeviceGroupedGemmXdl
                         c_grid_desc_m_n_,
                         c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                         block_2_ctile_map_,
-                        static_cast<const ADataType*>(gemm_shapes[i].p_a),
-                        static_cast<const BDataType*>(gemm_shapes[i].p_b),
-                        static_cast<CDataType*>(gemm_shapes[i].p_c),
+                        static_cast<const ADataType*>(p_a[i]),
+                        static_cast<const BDataType*>(p_b[i]),
+                        static_cast<CDataType*>(p_c[i]),
                         BlockStart,
                         BlockEnd});
                 }
...
@@ -318,7 +329,7 @@ struct DeviceGroupedGemmXdl
         std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

-        index_t grid_size;
+        index_t grid_size_;
     };

     // Invoker
...
@@ -395,7 +406,7 @@ struct DeviceGroupedGemmXdl
             ave_time = launch_and_time_kernel(kernel,
                                               nrepeat,
-                                              dim3(arg.grid_size),
+                                              dim3(arg.grid_size_),
                                               dim3(BlockSize),
                                               0,
                                               gemm_desc_kernel_arg_arg,
...
@@ -419,7 +430,7 @@ struct DeviceGroupedGemmXdl
             ave_time = launch_and_time_kernel(kernel,
                                               nrepeat,
-                                              dim3(arg.grid_size),
+                                              dim3(arg.grid_size_),
                                               dim3(BlockSize),
                                               0,
                                               gemm_desc_kernel_arg_arg,
...
@@ -459,25 +470,31 @@ struct DeviceGroupedGemmXdl
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

-    static auto MakeArgument(std::vector<GemmShape> gemm_shapes,
+    static auto MakeArgument(std::vector<const void*> p_a,
+                             std::vector<const void*> p_b,
+                             std::vector<void*> p_c,
+                             std::vector<GemmShape> gemm_shapes,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op)
    {
-        return Argument{gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
+        return Argument{
+            p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
+                                                      std::vector<const void*> p_b,
+                                                      std::vector<void*> p_c,
+                                                      std::vector<GemmShape>& gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,
                                                       index_t /* KBatch */ = 1) override
    {
-        return std::make_unique<Argument>(
-            gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
+        return std::make_unique<Argument>(
+            p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
...
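Besides the rename to grid_size_, the Argument constructor above gives each group a contiguous block range [BlockStart, BlockEnd) inside a single launch grid by accumulating the per-group grid sizes; those bounds are stored in GemmDescKernelArg so a workgroup can map its block id back to its group. A standalone sketch of that bookkeeping (the per-group sizes below are made up for illustration):

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Pretend GridwiseGemm::CalculateGridSize returned these per-group block counts.
        std::vector<int> grid_size_grp{12, 20, 28, 36};

        int grid_size_ = 0; // running total, as in DeviceGroupedGemmXdl::Argument
        for(std::size_t i = 0; i < grid_size_grp.size(); i++)
        {
            const int BlockStart = grid_size_;
            const int BlockEnd   = grid_size_ + grid_size_grp[i];
            grid_size_ += grid_size_grp[i];
            std::printf("group %zu owns blocks [%d, %d)\n", i, BlockStart, BlockEnd);
        }
        // The kernel is launched with dim3(grid_size_) blocks in total; each workgroup
        // finds its group by locating its block id inside one of these ranges.
        std::printf("total grid size: %d\n", grid_size_);
    }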
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp

...
@@ -85,7 +85,6 @@ __global__ void
                                                   c_element_op,
-                                                  gemm_desc_ptr[group_id].block_2_ctile_map_,
                                                   block_id_grp);
 #endif
 }
...