Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f41a265a
Commit
f41a265a
authored
Dec 01, 2023
by
Adam Osewski
Browse files
Fix allocation and setting workspace pointer.
parent
defa2071
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+6
-5
No files found.
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
View file @
f41a265a
...
@@ -228,10 +228,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -228,10 +228,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
DeviceMem
gemm_workspace_dev
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
hip_check_error
(
hipMemcpy
(
gemm_arg_dev_mem
.
GetDeviceBuffer
(),
hip_check_error
(
hipMemcpy
(
gemm_arg_dev_mem
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
grouped_gemm_kernel_args_
.
data
(),
gemm
.
GetDeviceKernelArgSize
(
&
argument
),
gemm
.
GetDeviceKernelArgSize
(
&
argument
),
...
@@ -247,7 +243,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -247,7 +243,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_arg_dev_mem
.
GetDeviceBuffer
());
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_arg_dev_mem
.
GetDeviceBuffer
());
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
DeviceMem
gemm_workspace_dev
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
...
@@ -289,6 +288,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -289,6 +288,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
],
c_host_tensors
[
i
]);
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
],
c_host_tensors
[
i
]);
}
}
std
::
cout
<<
"Verification: "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
"!"
<<
std
::
endl
;
}
}
return
pass
;
return
pass
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment