Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
67e84aed
Commit
67e84aed
authored
Sep 14, 2022
by
Po-Yen, Chen
Browse files
Remove cmd arg parsing logics
parent
234c0580
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
127 deletions
+41
-127
example/37_permute/common.hpp
example/37_permute/common.hpp
+0
-65
example/37_permute/permute_1xHxW_fp16.cpp
example/37_permute/permute_1xHxW_fp16.cpp
+1
-4
example/37_permute/permute_HxWx4_fp16.cpp
example/37_permute/permute_HxWx4_fp16.cpp
+1
-4
example/37_permute/permute_NxHxW_fp16.cpp
example/37_permute/permute_NxHxW_fp16.cpp
+1
-4
example/37_permute/run_permute_example.inc
example/37_permute/run_permute_example.inc
+38
-50
No files found.
example/37_permute/common.hpp
View file @
67e84aed
...
@@ -27,12 +27,6 @@ using F16 = ck::half_t;
...
@@ -27,12 +27,6 @@ using F16 = ck::half_t;
using
F32
=
float
;
using
F32
=
float
;
using
F64
=
double
;
using
F64
=
double
;
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
bool
time_kernel
=
true
;
};
struct
Problem
final
struct
Problem
final
{
{
static
constexpr
std
::
size_t
NumDim
=
3
;
static
constexpr
std
::
size_t
NumDim
=
3
;
...
@@ -273,65 +267,6 @@ is_valid_axes(const Axes& axes)
...
@@ -273,65 +267,6 @@ is_valid_axes(const Axes& axes)
(
*
std
::
prev
(
last
)
==
size
(
axes
)
-
1
);
(
*
std
::
prev
(
last
)
==
size
(
axes
)
-
1
);
}
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
{
constexpr
int
num_execution_config_args
=
2
;
constexpr
int
num_problem_args
=
2
*
Problem
::
NumDim
;
if
(
!
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
)))
{
return
false
;
}
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
1
+
num_execution_config_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
2
]);
}
else
if
(
argc
==
1
+
num_execution_config_args
+
num_problem_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
2
]);
// read shape
for
(
std
::
size_t
idx
=
0
;
idx
<
size
(
problem
.
shape
);
++
idx
)
{
problem
.
shape
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
(
1
+
num_execution_config_args
)]);
}
// read axes
for
(
std
::
size_t
idx
=
0
;
idx
<
size
(
problem
.
axes
);
++
idx
)
{
problem
.
axes
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
(
1
+
num_execution_config_args
+
size
(
problem
.
shape
))]);
}
if
(
!
is_valid_axes
(
problem
.
axes
))
{
std
::
cerr
<<
"invalid axes: "
;
std
::
copy
(
begin
(
problem
.
axes
),
end
(
problem
.
axes
),
std
::
ostream_iterator
<
std
::
size_t
>
(
std
::
cerr
,
" "
));
std
::
cerr
<<
std
::
endl
;
return
false
;
}
}
else
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3 ~ arg5: shape for 3D tensor"
<<
std
::
endl
<<
"arg6 ~ arg8: axes to permute"
<<
std
::
endl
;
return
false
;
}
return
true
;
}
template
<
typename
Shape
>
template
<
typename
Shape
>
inline
std
::
enable_if_t
<
detail
::
is_range_v
<
Shape
>
,
bool
>
is_valid_shape
(
const
Shape
&
shape
)
inline
std
::
enable_if_t
<
detail
::
is_range_v
<
Shape
>
,
bool
>
is_valid_shape
(
const
Shape
&
shape
)
{
{
...
...
example/37_permute/permute_1xHxW_fp16.cpp
View file @
67e84aed
...
@@ -17,7 +17,4 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
...
@@ -17,7 +17,4 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
#include "run_permute_example.inc"
#include "run_permute_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
int
main
()
{
return
!
run_permute_example
({
1
,
16000
,
80
},
{
0
,
2
,
1
});
}
{
return
!
run_permute_example
(
argc
,
argv
,
{
1
,
16000
,
80
},
{
0
,
2
,
1
});
}
example/37_permute/permute_HxWx4_fp16.cpp
View file @
67e84aed
...
@@ -20,7 +20,4 @@ static_assert(std::is_same_v<detail::get_bundled_t<F64, NUM_ELEMS_IN_BUNDLE>, F1
...
@@ -20,7 +20,4 @@ static_assert(std::is_same_v<detail::get_bundled_t<F64, NUM_ELEMS_IN_BUNDLE>, F1
#include "run_permute_example.inc"
#include "run_permute_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
int
main
()
{
return
!
run_permute_example
({
1
,
80
,
16000
},
{
0
,
2
,
1
});
}
{
return
!
run_permute_example
(
argc
,
argv
,
{
1
,
80
,
16000
},
{
0
,
2
,
1
});
}
example/37_permute/permute_NxHxW_fp16.cpp
View file @
67e84aed
...
@@ -17,7 +17,4 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
...
@@ -17,7 +17,4 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
#include "run_permute_example.inc"
#include "run_permute_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
int
main
()
{
return
!
run_permute_example
({
121
,
768
,
80
},
{
0
,
2
,
1
});
}
{
return
!
run_permute_example
(
argc
,
argv
,
{
121
,
768
,
80
},
{
0
,
2
,
1
});
}
example/37_permute/run_permute_example.inc
View file @
67e84aed
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
#define NUM_ELEMS_IN_BUNDLE 1
#define NUM_ELEMS_IN_BUNDLE 1
#endif
#endif
bool
run_permute
(
const
ExecutionConfig
&
config
,
const
Problem
&
problem
)
bool
run_permute
(
const
Problem
&
problem
)
{
{
#if 1 < NUM_ELEMS_IN_BUNDLE
#if 1 < NUM_ELEMS_IN_BUNDLE
static_assert
(
std
::
is_same_v
<
ADataType
,
BDataType
>
&&
static_assert
(
std
::
is_same_v
<
ADataType
,
BDataType
>
&&
...
@@ -61,64 +61,52 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
...
@@ -61,64 +61,52 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
};
};
auto
invoker
=
permute
.
MakeInvoker
();
auto
invoker
=
permute
.
MakeInvoker
();
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
true
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
if
(
config
.
do_verification
)
b_device_buf
.
FromDevice
(
data
(
b
.
mData
));
{
b_device_buf
.
FromDevice
(
data
(
b
.
mData
));
#if NUM_ELEMS_IN_BUNDLE == 1
#if NUM_ELEMS_IN_BUNDLE == 1
Tensor
<
BDataType
>
host_b
(
transposed_shape
);
Tensor
<
BDataType
>
host_b
(
transposed_shape
);
if
(
!
host_permute
(
a
,
problem
.
axes
,
PassThrough
{},
host_b
))
if
(
!
host_permute
(
a
,
problem
.
axes
,
PassThrough
{},
host_b
))
{
{
return
false
;
return
false
;
}
}
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
6
,
1
e
-
6
);
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
6
,
1
e
-
6
);
#else
#else
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
using
DataType
=
detail
::
get_bundled_t
<
ADataType
,
NUM_ELEMS_IN_BUNDLE
>
;
using
DataType
=
detail
::
get_bundled_t
<
ADataType
,
NUM_ELEMS_IN_BUNDLE
>
;
const
auto
extended_shape
=
extend_shape
(
shape
,
NUM_ELEMS_IN_BUNDLE
);
const
auto
extended_shape
=
extend_shape
(
shape
,
NUM_ELEMS_IN_BUNDLE
);
const
auto
extended_axes
=
extend_axes
(
problem
.
axes
);
const
auto
extended_axes
=
extend_axes
(
problem
.
axes
);
ck
::
remove_cvref_t
<
decltype
(
extended_shape
)
>
transposed_extended_shape
;
ck
::
remove_cvref_t
<
decltype
(
extended_shape
)
>
transposed_extended_shape
;
transpose_shape
(
extended_shape
,
extended_axes
,
begin
(
transposed_extended_shape
));
transpose_shape
(
extended_shape
,
extended_axes
,
begin
(
transposed_extended_shape
));
Tensor
<
DataType
>
extended_a
(
extended_shape
);
Tensor
<
DataType
>
extended_a
(
extended_shape
);
std
::
memcpy
(
data
(
extended_a
.
mData
),
std
::
memcpy
(
data
(
a
.
mData
),
data
(
extended_a
.
mData
),
data
(
a
.
mData
),
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
Tensor
<
DataType
>
extended_host_b
(
transposed_extended_shape
);
Tensor
<
DataType
>
extended_host_b
(
transposed_extended_shape
);
if
(
!
host_permute
(
extended_a
,
extended_axes
,
PassThrough
{},
extended_host_b
))
if
(
!
host_permute
(
extended_a
,
extended_axes
,
PassThrough
{},
extended_host_b
))
{
{
return
false
;
return
false
;
}
return
ck
::
utils
::
check_err
(
ck
::
span
<
const
DataType
>
{
reinterpret_cast
<
DataType
*>
(
data
(
b
.
mData
)),
b
.
mDesc
.
GetElementSpaceSize
()
*
NUM_ELEMS_IN_BUNDLE
},
ck
::
span
<
const
DataType
>
{
extended_host_b
.
mData
},
"Error: incorrect results in output tensor"
,
1
e
-
6
,
1
e
-
6
);
#endif
}
}
return
true
;
return
ck
::
utils
::
check_err
(
ck
::
span
<
const
DataType
>
{
reinterpret_cast
<
DataType
*>
(
data
(
b
.
mData
)),
b
.
mDesc
.
GetElementSpaceSize
()
*
NUM_ELEMS_IN_BUNDLE
},
ck
::
span
<
const
DataType
>
{
extended_host_b
.
mData
},
"Error: incorrect results in output tensor"
,
1
e
-
6
,
1
e
-
6
);
#endif
}
}
bool
run_permute_example
(
int
argc
,
bool
run_permute_example
(
const
Problem
::
Shape
&
default_shape
,
const
Problem
::
Axes
&
default_axes
)
char
*
argv
[],
const
Problem
::
Shape
&
default_shape
,
const
Problem
::
Axes
&
default_axes
)
{
{
ExecutionConfig
config
;
return
run_permute
(
Problem
{
default_shape
,
default_axes
});
Problem
problem
(
default_shape
,
default_axes
);
return
parse_cmd_args
(
argc
,
argv
,
config
,
problem
)
&&
run_permute
(
config
,
problem
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment