gaoqiong / composable_kernel / Commits

Commit c0972543, authored Jul 19, 2023 by Jing Zhang (parent 1c485e01)

clean; fixed comments

Showing 6 changed files with 7 additions and 67 deletions (+7, -67)
client_example/20_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp  (+1, -5)
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp  (+3, -9)
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp  (+2, -5)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp  (+1, -1)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp  (+0, -46)
profiler/src/CMakeLists.txt  (+0, -1)
client_example/20_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
@@ -56,9 +56,6 @@ struct SimpleDeviceMem
int main()
{
    std::mt19937 gen(19391);
    std::uniform_int_distribution<> distrib(1, 10);

    std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;

    int sum_of_m = 0;
@@ -123,7 +120,7 @@ int main()
        e_dev_bufs.emplace_back(sizeof(EDataType) *
                                f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));

        gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 0, StrideBs[i], 0, {0}});
        gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});

        p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
@@ -248,7 +245,6 @@ int main()
    auto invoker_ptr = op_ptr->MakeInvokerPointer();

    SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
    // op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());

    if(op_ptr->IsSupportedArgument(argument_ptr.get()))
    {
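For context on the lines above: a fixed-NK grouped GEMM instance sizes its own scratch buffer via GetWorkSpaceSize() and expects that buffer to be bound with SetWorkSpacePointer() (the call shown commented out) before the argument is validated and launched. The helper below is a minimal sketch of that sequence, not code from this commit: OpPtr/ArgPtr stand for the op and argument smart pointers used in the example, and the bare hipMalloc/hipFree calls are an illustrative stand-in for the example's SimpleDeviceMem wrapper.

    // Illustrative sketch only (not from this commit). OpPtr/ArgPtr are assumed to
    // expose the members used in the example above: GetWorkSpaceSize,
    // SetWorkSpacePointer, IsSupportedArgument, MakeInvokerPointer.
    #include <hip/hip_runtime.h>

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>

    template <typename OpPtr, typename ArgPtr>
    float run_with_workspace(OpPtr& op_ptr, ArgPtr& argument_ptr)
    {
        // Ask the device op how much scratch memory this argument needs.
        const std::size_t workspace_bytes = op_ptr->GetWorkSpaceSize(argument_ptr.get());

        void* p_workspace = nullptr;
        if(hipMalloc(&p_workspace, workspace_bytes) != hipSuccess)
            throw std::runtime_error("hipMalloc failed");

        // Bind the workspace to the argument before validation and launch.
        op_ptr->SetWorkSpacePointer(argument_ptr.get(), p_workspace);

        float ave_time = -1.f;
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            auto invoker_ptr = op_ptr->MakeInvokerPointer();
            // Run with the invoker's default stream configuration.
            ave_time = invoker_ptr->Run(argument_ptr.get());
        }
        else
        {
            std::cout << "this instance does not support the argument" << std::endl;
        }

        (void)hipFree(p_workspace);
        return ave_time;
    }

In the client example itself the same steps appear inline, with SimpleDeviceMem owning the allocation.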
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
@@ -202,13 +202,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
        p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());

        gemm_descs.push_back(
            {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], 0, problem_size.stride_Bs[i], 0, {0}});
        gemm_descs.push_back(
            {1, problem_size.Ns[i], problem_size.Ks[i], 1, problem_size.stride_Bs[i], 1, {0}});

        grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
@@ -320,8 +315,7 @@ int main(int argc, char* argv[])
    problem_size.group_count = 16;

    problem_size.Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
    problem_size.Ms = {167, 0, 177, 181, 153, 0, 156, 173, 645, 150, 204, 184, 168, 156, 168, 148};

    for(int i = 0; i < problem_size.group_count; i++)
    {
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
@@ -188,9 +188,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
        gemm_descs.push_back({sum_of_m,
                              problem_size.Ns[i],
                              problem_size.Ks[i],
                              problem_size.stride_As[i],
                              problem_size.stride_Bs[i],
                              problem_size.stride_Cs[i],
                              {}});
        gemm_descs.push_back({sum_of_m,
                              problem_size.Ns[i],
                              problem_size.Ks[i],
                              1,
                              problem_size.stride_Bs[i],
                              1,
                              {}});

        grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
@@ -223,8 +223,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
    // gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());

    hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
                              grouped_gemm_kernel_args_.data(),
                              gemm.GetWorkSpaceSize(&argument),
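The hipMemcpy above is the other half of the fixed-NK contract: the per-group pointers and sizes live in grouped_gemm_kernel_args_ on the host and are copied into the device workspace so the kernel can read them at run time. Below is a minimal, self-contained sketch of that upload step, assuming a simplified PerGroupArg struct as a stand-in for the library's GroupedGemmKernelArgument; it is illustrative only, not code from this commit.

    // Minimal sketch: copy host-side per-group arguments into a device buffer,
    // mirroring the hipMemcpy(...) of grouped_gemm_kernel_args_ shown above.
    // PerGroupArg is a simplified, assumed stand-in for GroupedGemmKernelArgument.
    #include <hip/hip_runtime.h>

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    struct PerGroupArg
    {
        const void* p_a;
        const void* p_b;
        void* p_e;
        std::int32_t M, N, K;
        std::int32_t StrideA, StrideB, StrideE;
    };

    inline void check_hip(hipError_t err)
    {
        if(err != hipSuccess)
            throw std::runtime_error(hipGetErrorString(err));
    }

    // Allocate a workspace large enough for all groups and copy the arguments over.
    void* upload_group_args(const std::vector<PerGroupArg>& args)
    {
        const std::size_t bytes = args.size() * sizeof(PerGroupArg);

        void* p_workspace = nullptr;
        check_hip(hipMalloc(&p_workspace, bytes));
        check_hip(hipMemcpy(p_workspace, args.data(), bytes, hipMemcpyHostToDevice));

        return p_workspace; // caller releases with hipFree
    }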
@@ -286,7 +284,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
    return pass;
}

// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
    ProblemSize problem_size;
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
@@ -4,7 +4,7 @@
#pragma once

#include <iostream>
#include <vector>
#include <array>

#include "device_grouped_gemm.hpp"
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
@@ -56,47 +56,11 @@ __global__ void
    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
        cast_pointer_to_generic_address_space(gemm_descs_const));

#if 0
    index_t left     = 0;
    index_t right    = group_count;
    index_t group_id = index_t((left + right) / 2);
    while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
             block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
          left <= right)
    {
        if(block_id < gemm_desc_ptr[group_id].BlockStart_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }

        group_id = index_t((left + right) / 2);
    }
#endif

    const index_t group_id = block_id / grid_size_grp;

    if(group_id >= group_count)
        return;

#if 0
    GridwiseGemm::template Run<HasMainKBlockLoop>(
        gemm_desc_ptr[group_id].a_ptr_,
        gemm_desc_ptr[group_id].b_ptr_,
        gemm_desc_ptr[group_id].ds_ptr_,
        gemm_desc_ptr[group_id].e_ptr_,
        p_shared,
        a_element_op,
        b_element_op,
        c_element_op,
        gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
        gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
        gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
        gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
        gemm_desc_ptr[group_id].block_2_etile_map_);
#else
    const index_t M = gemm_desc_ptr[group_id].M;
    const index_t N = gemm_desc_ptr[group_id].N;
    const index_t K = gemm_desc_ptr[group_id].K;
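The #if 0 block deleted here located a block's group by searching the per-group [BlockStart_, BlockEnd_) ranges, whereas the surviving line relies on every group owning exactly grid_size_grp consecutive blocks and simply divides. The stand-alone sketch below contrasts the two mappings and checks that they agree under that uniform partition; all names here (BlockRange, map_by_division, map_by_search) are illustrative, and std::upper_bound stands in for the hand-written binary search that was removed.

    // Illustration only: two ways to map a flat block id to its group.
    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct BlockRange
    {
        int BlockStart_; // first block of the group
        int BlockEnd_;   // one past the last block of the group
    };

    // Mirrors `group_id = block_id / grid_size_grp` in the kernel.
    int map_by_division(int block_id, int grid_size_grp) { return block_id / grid_size_grp; }

    // Stands in for the removed search over BlockStart_/BlockEnd_:
    // find the last range whose BlockStart_ is <= block_id.
    int map_by_search(int block_id, const std::vector<BlockRange>& ranges)
    {
        auto it = std::upper_bound(ranges.begin(), ranges.end(), block_id,
                                   [](int id, const BlockRange& r) { return id < r.BlockStart_; });
        return static_cast<int>(it - ranges.begin()) - 1;
    }

    int main()
    {
        const int group_count   = 4;
        const int grid_size_grp = 8; // uniform partition: same block count per group

        // Build the ranges the way the host side accumulates BlockStart/BlockEnd.
        std::vector<BlockRange> ranges;
        for(int g = 0, start = 0; g < group_count; ++g, start += grid_size_grp)
            ranges.push_back({start, start + grid_size_grp});

        // Under the uniform partition both mappings agree for every block id.
        for(int block_id = 0; block_id < group_count * grid_size_grp; ++block_id)
            assert(map_by_division(block_id, grid_size_grp) == map_by_search(block_id, ranges));

        return 0;
    }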
@@ -158,9 +122,6 @@ __global__ void
            m_id += 1;

        } while(m_id < m_loops);
#endif
#else
    ignore = gemm_descs_const;
    ignore = group_count;
@@ -644,9 +605,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

            // std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp <<
            // std::endl;

            const index_t BlockStart = grid_size_;
            const index_t BlockEnd   = grid_size_ + grid_size_grp;
@@ -731,11 +689,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        {
            bool has_main_k_block_loop = true;

#if 1
            std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;

            grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
#endif

            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
            {
@@ -788,7 +744,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                }

#if 1
                grouped_gemm_kernel_args.push_back(
                    GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
                                                          arg.gemm_desc_kernel_arg_[i].b_ptr_,
@@ -801,7 +756,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                          arg.gemm_desc_kernel_arg_[i].StrideB_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideDs_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideE_});
#endif
            }

            float ave_time = 0;
profiler/src/CMakeLists.txt
@@ -75,7 +75,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_bias_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)