Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2d6fe2cd
Commit
2d6fe2cd
authored
Feb 16, 2023
by
Adam Osewski
Browse files
Add argument to choose input layout.
parent
fc1edc32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
125 additions
and
30 deletions
+125
-30
test/gemm/gemm_standalone_xdl_fp16.cpp
test/gemm/gemm_standalone_xdl_fp16.cpp
+125
-30
No files found.
test/gemm/gemm_standalone_xdl_fp16.cpp
View file @
2d6fe2cd
...
@@ -54,20 +54,110 @@ struct LayoutConfig
...
@@ -54,20 +54,110 @@ struct LayoutConfig
bool
CRowMajor
;
bool
CRowMajor
;
};
};
// Selects which A/B input-layout combination the test should exercise.
// The two letters refer to matrices A and B respectively (presumably N = not transposed,
// T = transposed -- TODO confirm against LayoutConfig). The numeric values match the
// "arg3" command-line encoding printed by main (0=NN, 1=NT, 2=TN, 3=TT); ALL is the
// default used when no layout is given.
enum class ABDataLayout : int
{
    NN  = 0,
    NT  = 1,
    TN  = 2,
    TT  = 3,
    ALL = 4,
};
// DeviceGemm is a class template parameterised on layout and precision types, so its
// instantiations cannot be stored together in one vector. They are therefore held through
// the abstract BaseOperator class and dynamic_cast() back to the concrete type when invoked.
// DeviceGemm does not expose its template arguments, so a separate book-keeping class,
// LayoutConfig, records which concrete type a given BaseOperator instance must be cast to.

// Factory callback: appends operator instances to the vector it is given.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);

// One test problem: GEMM parameters, the matching layout description, and the factory that
// provides the operator instances to run.
using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void
insertNNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
v
.
insert
(
std
::
begin
(
v
),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_256x128
},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
});
}
void
insertNTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
v
.
insert
(
std
::
begin
(
v
),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_256x128
},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
});
}
void
insertTNProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
v
.
insert
(
std
::
begin
(
v
),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_256x128
},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
});
}
void
insertTTProblems
(
std
::
vector
<
ProblemDesc
>&
v
)
{
v
.
insert
(
std
::
begin
(
v
),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_256x128
},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
});
}
// Adds the test problems for the requested A/B layout to `v`.
//
// NN, NT, TN and TT each add only their own problem set. ABDataLayout::ALL, and any value
// that is not one of those four, adds all four sets by calling the NN, NT, TN and TT
// helpers in that order. Each helper inserts at the front of `v`, so after an "all" request
// the new entries appear in TT, TN, NT, NN order, ahead of anything `v` already held.
void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
{
    const bool only_nn = (layout == ABDataLayout::NN);
    const bool only_nt = (layout == ABDataLayout::NT);
    const bool only_tn = (layout == ABDataLayout::TN);
    const bool only_tt = (layout == ABDataLayout::TT);

    // Anything that is not a single specific layout selects every layout.
    const bool every_layout = !(only_nn || only_nt || only_tn || only_tt);

    if(every_layout || only_nn)
    {
        insertNNProblems(v);
    }
    if(every_layout || only_nt)
    {
        insertNTProblems(v);
    }
    if(every_layout || only_tn)
    {
        insertTNProblems(v);
    }
    if(every_layout || only_tt)
    {
        insertTTProblems(v);
    }
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
std
::
vector
<
ProblemDesc
>
problems
;
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// = {
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
// clang-format off
using
OpFactoryFn
=
void
(
*
)(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
);
std
::
vector
<
std
::
tuple
<
GemmParams
,
LayoutConfig
,
OpFactoryFn
>>
problems
=
{
// clang-format off
// Use following if you run it on MI200 GPU
// Use following if you run it on MI200 GPU
// 104 tiles
// 104 tiles
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
...
@@ -87,22 +177,22 @@ int main(int argc, char* argv[])
...
@@ -87,22 +177,22 @@ int main(int argc, char* argv[])
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// wavelet
// wavelet
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_256x256
},
//
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_256x128
},
//
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_128x128
},
//
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_wavelet_f16_nn_128x64
},
//
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_256x256
},
//
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_256x128
},
//
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_128x128
},
//
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_wavelet_f16_nt_128x64
},
//
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_256x256
},
//
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_256x128
},
//
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_128x128
},
//
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_wavelet_f16_tn_128x64
},
//
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_256x256
},
//
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_256x128
},
//
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_128x128
},
//
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_wavelet_f16_tt_128x64
},
//
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64},
// 110 tiles
// 110 tiles
...
@@ -122,28 +212,33 @@ int main(int argc, char* argv[])
...
@@ -122,28 +212,33 @@ int main(int argc, char* argv[])
// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
// clang-format on
};
//
};
bool
do_verification
=
true
;
bool
do_verification
=
true
;
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
auto
input_layout
=
ABDataLayout
::
ALL
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
// use default
// use default
}
}
else
if
(
argc
==
3
)
else
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
time_kernel
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
2
]);
input_layout
=
ABDataLayout
{
std
::
stoi
(
argv
[
3
])};
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: time kernel (0=no, 1=yes)"
<<
std
::
endl
;
<<
"arg2: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: Input data layout (0=NN, 1=NT, 2=TN, 3=TT)"
<<
std
::
endl
;
return
0
;
return
0
;
}
}
get_problems
(
problems
,
input_layout
);
bool
pass
=
true
;
bool
pass
=
true
;
for
(
auto
&
p
:
problems
)
for
(
auto
&
p
:
problems
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment