Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
98ef0abb
Unverified
Commit
98ef0abb
authored
Dec 05, 2023
by
Umang Yadav
Committed by
GitHub
Dec 05, 2023
Browse files
Device kernels using FP8 (#2510)
parent
6d0b6bcf
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
290 additions
and
168 deletions
+290
-168
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+10
-0
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+1
-1
test/verify/gemm_add_broadcast2.cpp
test/verify/gemm_add_broadcast2.cpp
+1
-1
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+204
-99
test/verify/test_contiguous.cpp
test/verify/test_contiguous.cpp
+6
-2
test/verify/test_logsoftmax.cpp
test/verify/test_logsoftmax.cpp
+4
-0
test/verify/test_multinomial.cpp
test/verify/test_multinomial.cpp
+9
-3
test/verify/test_nonzero.cpp
test/verify/test_nonzero.cpp
+7
-2
test/verify/test_nonzero_half.cpp
test/verify/test_nonzero_half.cpp
+0
-43
test/verify/test_prefix_scan_sum_2d.cpp
test/verify/test_prefix_scan_sum_2d.cpp
+15
-4
test/verify/test_reverse.cpp
test/verify/test_reverse.cpp
+7
-2
test/verify/test_rnn_sql_1.cpp
test/verify/test_rnn_sql_1.cpp
+11
-6
test/verify/test_scatter0.cpp
test/verify/test_scatter0.cpp
+8
-3
test/verify/test_topk_0.cpp
test/verify/test_topk_0.cpp
+7
-2
No files found.
src/targets/gpu/target.cpp
View file @
98ef0abb
...
@@ -110,6 +110,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -110,6 +110,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
{
{
unsupported_fp8_ops
.
insert
(
"dot"
);
unsupported_fp8_ops
.
insert
(
"dot"
);
}
}
// add all device kernels
unsupported_fp8_ops
.
insert
(
"logsoftmax"
);
unsupported_fp8_ops
.
insert
(
"nonzero"
);
unsupported_fp8_ops
.
insert
(
"prefix_scan_sum"
);
unsupported_fp8_ops
.
insert
(
"scatter_none"
);
unsupported_fp8_ops
.
insert
(
"topk"
);
unsupported_fp8_ops
.
insert
(
"rnn_var_sl_shift_output"
);
unsupported_fp8_ops
.
insert
(
"multinomial"
);
unsupported_fp8_ops
.
insert
(
"argmax"
);
unsupported_fp8_ops
.
insert
(
"argmin"
);
// clang-format off
// clang-format off
return
return
{
{
...
...
test/verify/gemm_2args_mm_8.cpp
View file @
98ef0abb
...
@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>>
...
@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>>
};
};
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
float_type
>;
// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
// fails with CK, issue#2514
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
gemm_2args_mm_8
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/gemm_add_broadcast2.cpp
View file @
98ef0abb
...
@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>>
...
@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>>
};
};
template
struct
gemm_add_broadcast2
<
migraphx
::
shape
::
float_type
>;
template
struct
gemm_add_broadcast2
<
migraphx
::
shape
::
float_type
>;
// template struct gemm_add_broadcast2<migraphx::shape::half_type>;
// template struct gemm_add_broadcast2<migraphx::shape::half_type>;
// fails with CK, issue#2514
template
struct
gemm_add_broadcast2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
gemm_add_broadcast2
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_arg_ops.cpp
View file @
98ef0abb
This diff is collapsed.
Click to expand it.
test/verify/test_contiguous.cpp
View file @
98ef0abb
...
@@ -29,16 +29,20 @@
...
@@ -29,16 +29,20 @@
#include <cassert>
#include <cassert>
struct
test_contiguous
:
verify_program
<
test_contiguous
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_contiguous
:
verify_program
<
test_contiguous
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
4
,
4
,
4
,
3
},
{
48
,
4
,
1
,
16
}};
migraphx
::
shape
s
{
DT
ype
,
{
4
,
4
,
4
,
3
},
{
48
,
4
,
1
,
16
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
assert
(
p
.
get_output_shapes
().
back
().
standard
());
assert
(
p
.
get_output_shapes
().
back
().
standard
());
return
p
;
return
p
;
}
}
};
};
template
struct
test_contiguous
<
migraphx
::
shape
::
float_type
>;
template
struct
test_contiguous
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_logsoftmax.cpp
View file @
98ef0abb
...
@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>;
...
@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>;
template
struct
test_logsoftmax
<
0
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
0
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
3
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
3
,
migraphx
::
shape
::
half_type
>;
template
struct
test_logsoftmax
<
0
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
test_logsoftmax
<
1
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
test_logsoftmax
<
2
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
test_logsoftmax
<
3
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_multinomial.cpp
View file @
98ef0abb
...
@@ -27,7 +27,8 @@
...
@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_multinomial
:
verify_program
<
test_multinomial
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_multinomial
:
verify_program
<
test_multinomial
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial>
...
@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial>
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
std
::
vector
<
float
>
rand_samples
(
batch_size
*
sample_size
);
std
::
vector
<
float
>
rand_samples
(
batch_size
*
sample_size
);
std
::
generate
(
rand_samples
.
begin
(),
rand_samples
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
std
::
generate
(
rand_samples
.
begin
(),
rand_samples
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
migraphx
::
shape
rs
{
migraphx
::
shape
::
float_t
ype
,
{
batch_size
,
sample_size
}};
migraphx
::
shape
rs
{
DT
ype
,
{
batch_size
,
sample_size
}};
auto
rs_lit
=
mm
->
add_literal
(
migraphx
::
literal
{
rs
,
rand_samples
});
auto
rs_lit
=
mm
->
add_literal
(
migraphx
::
literal
{
rs
,
rand_samples
});
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
batch_size
,
5
}};
migraphx
::
shape
s
{
DT
ype
,
{
batch_size
,
5
}};
auto
input
=
mm
->
add_parameter
(
"input"
,
s
);
auto
input
=
mm
->
add_parameter
(
"input"
,
s
);
auto
maxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
input
);
auto
maxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
input
);
...
@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial>
...
@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial>
return
p
;
return
p
;
}
}
};
};
template
struct
test_multinomial
<
migraphx
::
shape
::
float_type
>;
template
struct
test_multinomial
<
migraphx
::
shape
::
half_type
>;
// This fails, need to figure out why
// template struct test_multinomial<migraphx::shape::fp8e4m3fnuz_type>;
test/verify/test_nonzero.cpp
View file @
98ef0abb
...
@@ -27,13 +27,14 @@
...
@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_nonzero
:
verify_program
<
test_nonzero
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_nonzero
:
verify_program
<
test_nonzero
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
s
{
DT
ype
,
{
2
,
3
,
4
,
5
}};
auto
x
=
mm
->
add_parameter
(
"data"
,
s
);
auto
x
=
mm
->
add_parameter
(
"data"
,
s
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonzero"
),
x
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonzero"
),
x
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
...
@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero>
...
@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero>
return
p
;
return
p
;
}
}
};
};
template
struct
test_nonzero
<
migraphx
::
shape
::
float_type
>;
template
struct
test_nonzero
<
migraphx
::
shape
::
half_type
>;
template
struct
test_nonzero
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_nonzero_half.cpp
deleted
100644 → 0
View file @
6d0b6bcf
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_nonzero_half
:
verify_program
<
test_nonzero_half
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
3
,
4
,
3
,
5
}};
auto
x
=
mm
->
add_parameter
(
"data"
,
s
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonzero"
),
x
);
mm
->
add_return
({
r
});
return
p
;
}
};
test/verify/test_prefix_scan_sum_2d.cpp
View file @
98ef0abb
...
@@ -23,16 +23,18 @@
...
@@ -23,16 +23,18 @@
*/
*/
#include "verify_program.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_prefix_scan_sum_2d_small
:
verify_program
<
test_prefix_scan_sum_2d_small
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_prefix_scan_sum_2d_small
:
verify_program
<
test_prefix_scan_sum_2d_small
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
1
}};
migraphx
::
shape
s
{
DT
ype
,
{
1
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
xb
=
auto
xb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
3
}}}),
x
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
3
}}}),
x
);
...
@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
...
@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
}
}
};
};
struct
test_prefix_scan_sum_2d_large
:
verify_program
<
test_prefix_scan_sum_2d_large
>
template
struct
test_prefix_scan_sum_2d_small
<
migraphx
::
shape
::
float_type
>;
template
struct
test_prefix_scan_sum_2d_small
<
migraphx
::
shape
::
half_type
>;
template
struct
test_prefix_scan_sum_2d_small
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_prefix_scan_sum_2d_large
:
verify_program
<
test_prefix_scan_sum_2d_large
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
1000
}};
migraphx
::
shape
s
{
DT
ype
,
{
3
,
1000
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
"prefix_scan_sum"
,
{{
"axis"
,
1
},
{
"exclusive"
,
false
}}),
x
);
migraphx
::
make_op
(
"prefix_scan_sum"
,
{{
"axis"
,
1
},
{
"exclusive"
,
false
}}),
x
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_prefix_scan_sum_2d_large
<
migraphx
::
shape
::
float_type
>;
template
struct
test_prefix_scan_sum_2d_large
<
migraphx
::
shape
::
half_type
>;
template
struct
test_prefix_scan_sum_2d_large
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_reverse.cpp
View file @
98ef0abb
...
@@ -26,16 +26,21 @@
...
@@ -26,16 +26,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_reverse
:
verify_program
<
test_reverse
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_reverse
:
verify_program
<
test_reverse
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
4
,
16
}};
migraphx
::
shape
s
{
DT
ype
,
{
4
,
16
}};
auto
a0
=
mm
->
add_parameter
(
"data"
,
s
);
auto
a0
=
mm
->
add_parameter
(
"data"
,
s
);
std
::
vector
<
int64_t
>
axis
=
{
0
};
std
::
vector
<
int64_t
>
axis
=
{
0
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"reverse"
,
{{
"axes"
,
axis
}}),
a0
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"reverse"
,
{{
"axes"
,
axis
}}),
a0
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_reverse
<
migraphx
::
shape
::
float_type
>;
template
struct
test_reverse
<
migraphx
::
shape
::
half_type
>;
template
struct
test_reverse
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_rnn_sql_1.cpp
View file @
98ef0abb
...
@@ -31,7 +31,8 @@
...
@@ -31,7 +31,8 @@
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
struct
test_rnn_sql_1
:
verify_program
<
test_rnn_sql_1
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_rnn_sql_1
:
verify_program
<
test_rnn_sql_1
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
...
@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
...
@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_t
ype
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape
{
DT
ype
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_t
ype
,
{
num_dirct
,
hidden_size
,
input_size
}};
migraphx
::
shape
w_shape
{
DT
ype
,
{
num_dirct
,
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_t
ype
,
{
num_dirct
,
hidden_size
,
hidden_size
}};
migraphx
::
shape
r_shape
{
DT
ype
,
{
num_dirct
,
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_t
ype
,
{
num_dirct
,
2
*
hidden_size
}};
migraphx
::
shape
b_shape
{
DT
ype
,
{
num_dirct
,
2
*
hidden_size
}};
migraphx
::
shape
s_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
shape
s_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_t
ype
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
DT
ype
,
{
num_dirct
,
batch_size
,
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
...
@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
...
@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
}
}
std
::
string
section
()
const
{
return
"rnn"
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
};
template
struct
test_rnn_sql_1
<
migraphx
::
shape
::
float_type
>;
template
struct
test_rnn_sql_1
<
migraphx
::
shape
::
half_type
>;
template
struct
test_rnn_sql_1
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_scatter0.cpp
View file @
98ef0abb
...
@@ -27,16 +27,17 @@
...
@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatter0
:
verify_program
<
test_scatter0
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_scatter0
:
verify_program
<
test_scatter0
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
3
}};
migraphx
::
shape
sd
{
DT
ype
,
{
3
,
3
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
migraphx
::
shape
si
{
migraphx
::
shape
::
int32_type
,
{
2
,
3
}};
std
::
vector
<
int
>
vi
=
{
1
,
0
,
2
,
0
,
2
,
1
};
std
::
vector
<
int
>
vi
=
{
1
,
0
,
2
,
0
,
2
,
1
};
migraphx
::
shape
su
{
migraphx
::
shape
::
float_t
ype
,
{
2
,
3
}};
migraphx
::
shape
su
{
DT
ype
,
{
2
,
3
}};
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
...
@@ -47,3 +48,7 @@ struct test_scatter0 : verify_program<test_scatter0>
...
@@ -47,3 +48,7 @@ struct test_scatter0 : verify_program<test_scatter0>
return
p
;
return
p
;
}
}
};
};
template
struct
test_scatter0
<
migraphx
::
shape
::
float_type
>;
template
struct
test_scatter0
<
migraphx
::
shape
::
half_type
>;
template
struct
test_scatter0
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
test/verify/test_topk_0.cpp
View file @
98ef0abb
...
@@ -27,13 +27,14 @@
...
@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
struct
test_topk_0
:
verify_program
<
test_topk_0
>
template
<
migraphx
::
shape
::
type_t
DType
>
struct
test_topk_0
:
verify_program
<
test_topk_0
<
DType
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_t
ype
,
{
3
,
5
}};
migraphx
::
shape
s
{
DT
ype
,
{
3
,
5
}};
auto
data
=
mm
->
add_parameter
(
"data"
,
s
);
auto
data
=
mm
->
add_parameter
(
"data"
,
s
);
auto
r
=
mm
->
add_instruction
(
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"topk"
,
{{
"axis"
,
1
},
{
"k"
,
4
},
{
"largest"
,
1
}}),
data
);
migraphx
::
make_op
(
"topk"
,
{{
"axis"
,
1
},
{
"k"
,
4
},
{
"largest"
,
1
}}),
data
);
...
@@ -43,3 +44,7 @@ struct test_topk_0 : verify_program<test_topk_0>
...
@@ -43,3 +44,7 @@ struct test_topk_0 : verify_program<test_topk_0>
return
p
;
return
p
;
}
}
};
};
template
struct
test_topk_0
<
migraphx
::
shape
::
float_type
>;
template
struct
test_topk_0
<
migraphx
::
shape
::
half_type
>;
template
struct
test_topk_0
<
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment