Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
88578483
Commit
88578483
authored
May 27, 2022
by
Jing Zhang
Browse files
add SetWorkSpacePointer
parent
51a549c9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
43 deletions
+25
-43
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+6
-10
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+2
-0
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+11
-31
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+6
-1
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
88578483
...
...
@@ -192,17 +192,13 @@ int main(int argc, char* argv[])
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
gemm_shapes
.
size
()));
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
gemm_desc_workspace
.
GetDeviceBuffer
(),
a_element_op
,
b_element_op
,
c_element_op
);
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
88578483
...
...
@@ -42,6 +42,8 @@ struct BaseOperator
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
void
SetWorkSpacePointer
(
BaseArgument
*
,
void
*
)
const
{}
virtual
~
BaseOperator
()
{}
};
...
...
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
88578483
...
...
@@ -51,7 +51,6 @@ struct DeviceGroupedGemm : public BaseOperator
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
88578483
...
...
@@ -350,7 +350,6 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
...
...
@@ -364,7 +363,7 @@ struct DeviceGroupedGemmXdl
{
grid_size_
=
0
;
gemm_descs_args_workspace_
=
gemm_descs_args_workspace
;
gemm_descs_args_workspace_
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
...
...
@@ -490,11 +489,6 @@ struct DeviceGroupedGemmXdl
}
}
// void* gemm_descs_args_workspace;
// hipGetErrorString(hipMalloc(
// &gemm_descs_args_workspace, arg.gemm_desc_kernel_arg_.size() *
// sizeof(GemmDescKernelArg)));
hipGetErrorString
(
hipMemcpy
(
arg
.
gemm_descs_args_workspace_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
...
...
@@ -587,21 +581,11 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_shapes
,
gemm_descs_args_workspace
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -611,22 +595,13 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
gemm_descs_args_workspace
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
...
...
@@ -658,9 +633,14 @@ struct DeviceGroupedGemmXdl
return
str
.
str
();
}
static
size_t
GetWorkSpaceSize
(
const
index_t
group_count
)
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmDescKernelArg
);
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
workspace_ptr
)
const
override
{
return
group_count
*
sizeof
(
G
emm
D
esc
KernelArg
)
;
dynamic_cast
<
Argument
*>
(
p_arg
)
->
g
emm
_d
esc
s_args_workspace_
=
workspace_ptr
;
}
};
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
88578483
...
...
@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto
c_element_op
=
PassThrough
{};
// do GEMM
auto
invoker_ptr
=
groupedGemmPtr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
groupedGemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
groupedGemmPtr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
groupedGemmPtr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
groupedGemmPtr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
for
(
std
::
size_t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment