Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
47e0607c
Commit
47e0607c
authored
Feb 17, 2023
by
Adam Osewski
Browse files
Fix clang-format
parent
e48e7f38
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
41 deletions
+30
-41
test/gemm/gemm_standalone_xdl_fp16.cpp
test/gemm/gemm_standalone_xdl_fp16.cpp
+30
-41
No files found.
test/gemm/gemm_standalone_xdl_fp16.cpp
View file @
47e0607c
...
@@ -4,11 +4,10 @@
...
@@ -4,11 +4,10 @@
#include "gemm_util.hpp"
#include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
// #include "gemm_f16_nn_instance.hpp"
// #include "gemm_f16_nn_instance.hpp"
// #include "gemm_f16_nt_instance.hpp"
// #include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
//
#include "gemm_f16_tn_instance.hpp"
// #include "gemm_f16_tt_instance.hpp"
// #include "gemm_f16_tt_instance.hpp"
#include "gemm_wavelet_f16_nn_instance.hpp"
#include "gemm_wavelet_f16_nn_instance.hpp"
#include "gemm_wavelet_f16_nt_instance.hpp"
#include "gemm_wavelet_f16_nt_instance.hpp"
...
@@ -73,61 +72,53 @@ using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
...
@@ -73,61 +72,53 @@ using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void
insertNNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
void
insertNNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
{
v
.
insert
(
std
::
begin
(
v
),
v
.
insert
(
std
::
end
(
v
),
{
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true},
// clang-format off
// add_gemm_wavelet_f16_nn_256x256},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_256x128
},
LayoutConfig
{
false
,
false
,
true
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_128x128
},
add_gemm_wavelet_f16_nn_256x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_128x64
}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true},
// clang-format on
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
});
});
}
}
void
insertNTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
void
insertNTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
{
v
.
insert
(
std
::
begin
(
v
),
v
.
insert
(
std
::
end
(
v
),
{
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true},
// clang-format off
// add_gemm_wavelet_f16_nt_256x256},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_256x128
},
LayoutConfig
{
false
,
true
,
true
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_128x128
},
add_gemm_wavelet_f16_nt_256x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_128x64
}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true},
// clang-format on
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
});
});
}
}
void
insertTNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
void
insertTNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
{
v
.
insert
(
std
::
begin
(
v
),
v
.
insert
(
std
::
end
(
v
),
{
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},
// clang-format off
// add_gemm_wavelet_f16_tn_256x256},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_256x128
},
LayoutConfig
{
true
,
false
,
true
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_128x128
},
add_gemm_wavelet_f16_tn_256x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_128x64
}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},
// clang-format on
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
});
});
}
}
void
insertTTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
void
insertTTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
{
v
.
insert
(
std
::
begin
(
v
),
v
.
insert
(
std
::
end
(
v
),
{
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},
// clang-format off
// add_gemm_wavelet_f16_tt_256x256},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_256x128
},
LayoutConfig
{
true
,
true
,
true
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_128x128
},
add_gemm_wavelet_f16_tt_256x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_128x64
}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},
// clang-format on
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
});
});
}
}
...
@@ -151,9 +142,7 @@ void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
...
@@ -151,9 +142,7 @@ void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
std
::
vector
<
ProblemDesc
>
problems
;
// std::vector<ProblemDesc> problems = {
// = {
// clang-format off
// clang-format off
// Use following if you run it on MI200 GPU
// Use following if you run it on MI200 GPU
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment