Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0c3cfcf8
"driver/vscode:/vscode.git/clone" did not exist on "89140d16e0e16ae85af85199ee3e33c88f9f670a"
Commit
0c3cfcf8
authored
Jul 19, 2023
by
Jing Zhang
Browse files
fixed comment
parent
881ba357
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
15 deletions
+14
-15
client_example/20_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
.../20_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
+7
-5
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
...e/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+0
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+7
-9
No files found.
client_example/20_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
View file @
0c3cfcf8
...
...
@@ -182,18 +182,19 @@ int main()
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
g
emm_desc_workspace
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
g
rouped_gemm_kernel_args_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipGetErrorString
(
hipMemcpy
(
g
emm_desc_workspace
.
GetDeviceBuffer
(),
hipGetErrorString
(
hipMemcpy
(
g
rouped_gemm_kernel_args_dev
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
));
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
...
...
@@ -244,11 +245,12 @@ int main()
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
g
emm_desc_workspace
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
g
rouped_gemm_kernel_args_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
View file @
0c3cfcf8
...
...
@@ -307,7 +307,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
return
pass
;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int
main
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
0c3cfcf8
...
...
@@ -496,7 +496,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
ck
::
index_t
BlockStart_
,
BlockEnd_
;
};
// Argument
...
...
@@ -605,9 +604,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
e_grid_desc_m_n
);
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockEnd
=
grid_size_
+
grid_size_grp
;
if
(
group_id
*
grid_size_grp
!=
grid_size_
)
{
throw
std
::
runtime_error
(
"wrong! grid_size_grp is not identical!"
);
...
...
@@ -655,9 +651,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
local_b2c_tile_map
,
BlockStart
,
BlockEnd
});
local_b2c_tile_map
});
}
group_id
++
;
...
...
@@ -777,8 +771,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEElementwiseOperation
,
has_main_k_block_loop_
>
;
const
index_t
grid_size_grp
=
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockEnd_
-
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockStart_
;
const
index_t
grid_size_grp
=
arg
.
grid_size_
/
arg
.
group_count_
;
const
void
*
kernel_args_dev
=
nullptr
;
...
...
@@ -798,6 +791,11 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}
if
(
arg
.
p_workspace_
==
nullptr
)
{
throw
std
::
runtime_error
(
"wrong! arg.p_workspace_ == nullptr"
);
}
hipGetErrorString
(
hipMemcpyWithStream
(
arg
.
p_workspace_
,
grouped_gemm_kernel_args
.
data
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment