Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
27f1c8cb
Commit
27f1c8cb
authored
Jul 17, 2022
by
Jing Zhang
Browse files
use array
parent
b78c8719
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
17 additions
and
14 deletions
+17
-14
example/15_grouped_gemm/grouped_gemm_bias_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_bias_xdl_fp16.cpp
+1
-1
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+3
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+11
-10
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+1
-1
No files found.
example/15_grouped_gemm/grouped_gemm_bias_xdl_fp16.cpp
View file @
27f1c8cb
...
...
@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
std
::
vector
<
const
void
*>
p_a
,
p_b
;
std
::
vector
<
std
::
vector
<
const
void
*>>
p_ds
;
std
::
vector
<
std
::
array
<
const
void
*
,
1
>>
p_ds
;
std
::
vector
<
void
*>
p_c
;
gemm_descs
.
reserve
(
group_count
);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
27f1c8cb
...
...
@@ -200,7 +200,7 @@ int main(int argc, char* argv[])
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
std
::
vector
<
std
::
vector
<
const
void
*>>
p_Ds
=
{};
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_Ds
=
{};
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
27f1c8cb
...
...
@@ -28,10 +28,12 @@ template <typename ALayout,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemm
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
std
::
vector
<
const
void
*>>&
p_ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
GemmDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
27f1c8cb
...
...
@@ -532,7 +532,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
{
Argument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
vector
<
const
void
*>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
...
...
@@ -755,7 +755,7 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static
auto
MakeArgument
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
vector
<
const
void
*>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>
gemm_descs
,
AElementwiseOperation
a_element_op
,
...
...
@@ -769,9 +769,10 @@ struct DeviceGroupedGemmXdl : public DeviceGroupedGemm<ALayout,
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_As
,
std
::
vector
<
const
void
*>&
p_Bs
,
std
::
vector
<
std
::
vector
<
const
void
*>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
27f1c8cb
...
...
@@ -175,7 +175,7 @@ bool profile_grouped_gemm_impl(int do_verification,
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
auto
p_ds
=
std
::
vector
<
std
::
vector
<
const
void
*>>
{};
auto
p_ds
=
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
{};
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
op_ptrs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment