Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5a5468f4
"profiler/vscode:/vscode.git/clone" did not exist on "959ddcf895c98f6948e62d33859d0aebed14f533"
Commit
5a5468f4
authored
Jul 18, 2023
by
Jing Zhang
Browse files
add SetDeviceKernelArgs
parent
3165d5d7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
126 additions
and
74 deletions
+126
-74
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
...xample/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
+30
-3
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+5
-5
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+89
-66
No files found.
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
View file @
5a5468f4
...
...
@@ -60,6 +60,8 @@ int main()
std
::
vector
<
int
>
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideEs
;
int
sum_of_m
=
0
;
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
Ms
.
push_back
(
256
+
256
*
distrib
(
gen
));
...
...
@@ -69,6 +71,8 @@ int main()
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideEs
.
push_back
(
std
::
is_same
<
Row
,
ELayout
>::
value
?
Ns
[
i
]
:
Ms
[
i
]);
sum_of_m
+=
Ms
[
i
];
}
auto
f_matrix_space_size
=
...
...
@@ -102,6 +106,10 @@ int main()
gemm_descs
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<>>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
a_dev_bufs
.
emplace_back
(
sizeof
(
ADataType
)
*
...
...
@@ -111,11 +119,23 @@ int main()
e_dev_bufs
.
emplace_back
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
StrideEs
[
i
],
ELayout
{}));
gemm_descs
.
push_back
({
Ms
[
i
]
,
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideEs
[
i
],
{}});
gemm_descs
.
push_back
({
sum_of_m
,
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
StrideEs
[
i
],
{}});
p_a
.
push_back
(
a_dev_bufs
[
i
].
GetDeviceBuffer
());
p_b
.
push_back
(
b_dev_bufs
[
i
].
GetDeviceBuffer
());
p_e
.
push_back
(
e_dev_bufs
[
i
].
GetDeviceBuffer
());
grouped_gemm_kernel_args_
.
push_back
({
a_dev_bufs
[
i
].
GetDeviceBuffer
(),
b_dev_bufs
[
i
].
GetDeviceBuffer
(),
{},
e_dev_bufs
[
i
].
GetDeviceBuffer
(),
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
{},
StrideEs
[
i
]});
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemm
<
ALayout
,
...
...
@@ -162,13 +182,20 @@ int main()
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
gemm_desc_workspace
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
//
op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipMemcpy
(
gemm_desc_workspace
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
);
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
gemm_descs
.
size
();
++
j
)
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
5a5468f4
...
...
@@ -223,7 +223,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
//
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error
(
hipMemcpy
(
gemm_desc_workspace
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
...
...
@@ -237,7 +237,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem"
);
}
invoker
.
Run
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
(),
StreamConfig
{
nullptr
,
false
});
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
if
(
config
.
do_verification
)
...
...
@@ -273,9 +275,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
if
(
config
.
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
(),
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
5a5468f4
...
...
@@ -66,6 +66,8 @@ struct DeviceGroupedGemm : public BaseOperator
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
5a5468f4
...
...
@@ -564,6 +564,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
grid_size_
=
0
;
grouped_gemm_kernel_args_dev
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_descs
.
size
());
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_As
.
size
())
||
...
...
@@ -713,6 +715,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
std
::
vector
<
Tuple
<
index_t
,
index_t
>>
a_mtx_mraw_kraw_
;
std
::
vector
<
Tuple
<
index_t
,
index_t
>>
b_mtx_nraw_kraw_
;
const
void
*
grouped_gemm_kernel_args_dev
;
index_t
grid_size_
;
};
...
...
@@ -721,65 +725,15 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
void
*
grouped_gemm_kernel_args_dev
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
bool
has_main_k_block_loop
=
true
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl
<
GridwiseGemm
,
GroupedGemmKernelArgument
<
NumDTensor
>
,
GemmSpec
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
Block2ETileMap
,
GroupedGemmBlock2ETileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
has_main_k_block_loop_
>
;
const
index_t
grid_size_grp
=
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockEnd_
-
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockStart_
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
grouped_gemm_kernel_args_dev
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
grid_size_grp
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
};
if
(
has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
bool
has_main_k_block_loop
=
true
;
#if 1
std
::
vector
<
GroupedGemmKernelArgument
<
NumDTensor
>>
grouped_gemm_kernel_args
;
grouped_gemm_kernel_args
.
reserve
(
arg
.
gemm_desc_kernel_arg_
.
size
());
#endif
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
...
...
@@ -824,13 +778,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
}
if
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_ptr_
==
nullptr
||
arg
.
gemm_desc_kernel_arg_
[
i
].
b_ptr_
==
nullptr
||
arg
.
gemm_desc_kernel_arg_
[
i
].
e_ptr_
==
nullptr
)
{
throw
std
::
runtime_error
(
"wrong! p_a/b/c_grid is nullptr"
);
}
#if 1
grouped_gemm_kernel_args
.
push_back
(
GroupedGemmKernelArgument
<
NumDTensor
>
{
arg
.
gemm_desc_kernel_arg_
[
i
].
a_ptr_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
b_ptr_
,
...
...
@@ -843,16 +791,80 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideB_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideDs_
,
arg
.
gemm_desc_kernel_arg_
[
i
].
StrideE_
});
#endif
}
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl
<
GridwiseGemm
,
GroupedGemmKernelArgument
<
NumDTensor
>
,
GemmSpec
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
Block2ETileMap
,
GroupedGemmBlock2ETileMap
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
has_main_k_block_loop_
>
;
const
index_t
grid_size_grp
=
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockEnd_
-
arg
.
gemm_desc_kernel_arg_
[
0
].
BlockStart_
;
const
void
*
kernel_args_dev
=
nullptr
;
if
(
arg
.
grouped_gemm_kernel_args_dev
!=
nullptr
)
{
kernel_args_dev
=
arg
.
grouped_gemm_kernel_args_dev
;
}
else
{
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
if
(
arg
.
gemm_desc_kernel_arg_
[
i
].
a_ptr_
==
nullptr
||
arg
.
gemm_desc_kernel_arg_
[
i
].
b_ptr_
==
nullptr
||
arg
.
gemm_desc_kernel_arg_
[
i
].
e_ptr_
==
nullptr
)
{
throw
std
::
runtime_error
(
"wrong! p_a/b/c_grid is nullptr"
);
}
}
hipGetErrorString
(
hipMemcpyWithStream
(
arg
.
p_workspace_
,
hipGetErrorString
(
hipMemcpyWithStream
(
arg
.
p_workspace_
,
grouped_gemm_kernel_args
.
data
(),
grouped_gemm_kernel_args
.
size
()
*
sizeof
(
GroupedGemmKernelArgument
<
NumDTensor
>
),
hipMemcpyHostToDevice
,
stream_config
.
stream_id_
));
float
ave_time
=
Run
(
arg
,
arg
.
p_workspace_
,
stream_config
);
kernel_args_dev
=
arg
.
p_workspace_
;
}
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
kernel_args_dev
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
grid_size_grp
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
};
if
(
has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
...
...
@@ -967,6 +979,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return
str
.
str
();
}
static
void
SetDeviceKernelArgs
(
Argument
&
arg
,
const
void
*
kernel_args
)
{
arg
.
grouped_gemm_kernel_args_dev
=
kernel_args
;
}
// polymorphic
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
override
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kernel_args
);
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment