Commit acd9bd3e authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'develop' into rocblas_mlir_fp8

parents b2542239 a09dc502
...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]}) ...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
Optimize when reading Optimize when reading
.. option:: --apply-pass, -p
Passes to apply to model
.. option:: --graphviz, -g .. option:: --graphviz, -g
Print out a graphviz representation. Print out a graphviz representation.
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
add_executable(driver add_executable(driver
main.cpp main.cpp
verify.cpp verify.cpp
passes.cpp
perf.cpp perf.cpp
resnet50.cpp resnet50.cpp
inceptionv3.cpp inceptionv3.cpp
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
#include "precision.hpp" #include "precision.hpp"
#include "passes.hpp"
#include "perf.hpp" #include "perf.hpp"
#include "models.hpp" #include "models.hpp"
#include "marker_roctx.hpp" #include "marker_roctx.hpp"
...@@ -83,6 +84,7 @@ struct loader ...@@ -83,6 +84,7 @@ struct loader
std::vector<std::string> param_dims; std::vector<std::string> param_dims;
std::vector<std::string> dyn_param_dims; std::vector<std::string> dyn_param_dims;
std::vector<std::string> output_names; std::vector<std::string> output_names;
std::vector<std::string> passes;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -130,6 +132,7 @@ struct loader ...@@ -130,6 +132,7 @@ struct loader
ap.append(), ap.append(),
ap.nargs(2)); ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
ap(output_type, ap(output_type,
{"--graphviz", "-g"}, {"--graphviz", "-g"},
ap.help("Print out a graphviz representation."), ap.help("Print out a graphviz representation."),
...@@ -337,6 +340,8 @@ struct loader ...@@ -337,6 +340,8 @@ struct loader
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
}); });
} }
if(not passes.empty())
migraphx::run_passes(*p.get_main_module(), get_passes(passes));
return p; return p;
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "passes.hpp"
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, pass> create_passes_lookup()
{
std::unordered_map<std::string, pass> result;
// clang-format off
std::initializer_list<pass> passes = {
auto_contiguous{},
dead_code_elimination{},
eliminate_allocation{},
eliminate_common_subexpression{},
eliminate_concat{},
eliminate_contiguous{},
eliminate_data_type{},
eliminate_identity{},
eliminate_pad{},
inline_module{},
insert_pad{},
normalize_ops{},
optimize_module{},
promote_literals{},
propagate_constant{},
rewrite_gelu{},
rewrite_pooling{},
rewrite_quantization{},
rewrite_rnn{},
simplify_algebra{},
simplify_dyn_ops{},
simplify_qdq{},
simplify_reshapes{},
};
// clang-format on
for(const auto& pass : passes)
result[pass.name()] = pass;
result["eliminate_dead_code"] = dead_code_elimination{};
return result;
}
std::vector<pass> get_passes(const std::vector<std::string>& names)
{
std::vector<pass> result;
static const std::unordered_map<std::string, pass> lookup = create_passes_lookup();
std::transform(
names.begin(), names.end(), std::back_inserter(result), [](const std::string& name) {
if(not contains(lookup, name))
MIGRAPHX_THROW("Unknown pass: " + name);
return lookup.at(name);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
...@@ -21,23 +21,20 @@ ...@@ -21,23 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#include "verify_program.hpp" #include <migraphx/pass.hpp>
#include <migraphx/program.hpp> #include <vector>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_nonzero_half : verify_program<test_nonzero_half> namespace migraphx {
{ namespace driver {
migraphx::program create_program() const inline namespace MIGRAPHX_INLINE_NS {
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {3, 4, 3, 5}};
auto x = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
mm->add_return({r});
return p; std::vector<pass> get_passes(const std::vector<std::string>& names);
}
}; } // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
...@@ -301,6 +301,7 @@ target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) ...@@ -301,6 +301,7 @@ target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif() endif()
add_subdirectory(driver) add_subdirectory(driver)
......
...@@ -543,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const ...@@ -543,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches( match::find_matches(
mpm, mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::all)}, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)}); find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
......
...@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x) ...@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
return y; return y;
} }
template <unsigned int DppCtrl, template <class T, class F>
unsigned int RowMask = 0xf, __device__ T dpp_op(T& x, F f)
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{ {
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4; static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type union type
...@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x) ...@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
input.data = x; input.data = x;
for(index_int i = 0; i < n; i++) for(index_int i = 0; i < n; i++)
{ {
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); output.reg[i] = f(input.reg[i]);
} }
return output.data; return output.data;
} }
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
return dpp_op(x,
[](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
}
template <unsigned int Mask, class T>
__device__ T dpp_swizzle(T& x)
{
return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
}
#endif // MIGRAPHX_HAS_DPP #endif // MIGRAPHX_HAS_DPP
} // namespace migraphx } // namespace migraphx
......
...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out); in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64 #if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<0x1e0>(in);
in = op(in, out);
#else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in); out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
...@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
} }
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK) #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1 #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64 #elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \ "s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); \
(void)f
#else #else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \ #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \ __device__ inline void dpp_reduce(double& x, op f) \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \ { \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
__device__ inline void dpp_reduce(int32_t& x, op) \ } \
{ \ __device__ inline void dpp_reduce(float& x, op f) \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \ { \
} \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); } } \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
// Note: when max and min are in int32_t, signed version of instruction needs to be used. // Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u) MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
...@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F> ...@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = type(init); type x = type(init);
......
...@@ -118,6 +118,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -118,6 +118,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_fp8_ops.insert("convolution"); unsupported_fp8_ops.insert("convolution");
unsupported_fp8_ops.insert("quant_convolution"); unsupported_fp8_ops.insert("quant_convolution");
} }
// add all device kernels
unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero");
unsupported_fp8_ops.insert("prefix_scan_sum");
unsupported_fp8_ops.insert("scatter_none");
unsupported_fp8_ops.insert("topk");
unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
unsupported_fp8_ops.insert("multinomial");
unsupported_fp8_ops.insert("argmax");
unsupported_fp8_ops.insert("argmin");
// clang-format off // clang-format off
return return
{ {
......
...@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>> ...@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>>
}; };
template struct gemm_2args_mm_8<migraphx::shape::float_type>; template struct gemm_2args_mm_8<migraphx::shape::float_type>;
// template struct gemm_2args_mm_8<migraphx::shape::half_type>; // template struct gemm_2args_mm_8<migraphx::shape::half_type>; // fails with CK, issue#2514
template struct gemm_2args_mm_8<migraphx::shape::fp8e4m3fnuz_type>; template struct gemm_2args_mm_8<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>> ...@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>>
}; };
template struct gemm_add_broadcast2<migraphx::shape::float_type>; template struct gemm_add_broadcast2<migraphx::shape::float_type>;
// template struct gemm_add_broadcast2<migraphx::shape::half_type>; // template struct gemm_add_broadcast2<migraphx::shape::half_type>; // fails with CK, issue#2514
template struct gemm_add_broadcast2<migraphx::shape::fp8e4m3fnuz_type>; template struct gemm_add_broadcast2<migraphx::shape::fp8e4m3fnuz_type>;
This diff is collapsed.
...@@ -29,16 +29,20 @@ ...@@ -29,16 +29,20 @@
#include <cassert> #include <cassert>
struct test_contiguous : verify_program<test_contiguous> template <migraphx::shape::type_t DType>
struct test_contiguous : verify_program<test_contiguous<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{DType, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("contiguous"), x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
template struct test_contiguous<migraphx::shape::float_type>;
template struct test_contiguous<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>; ...@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>; template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>; template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>; template struct test_logsoftmax<3, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_multinomial : verify_program<test_multinomial> template <migraphx::shape::type_t DType>
struct test_multinomial : verify_program<test_multinomial<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial> ...@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial>
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(batch_size * sample_size); std::vector<float> rand_samples(batch_size * sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape rs{DType, {batch_size, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}}; migraphx::shape s{DType, {batch_size, 5}};
auto input = mm->add_parameter("input", s); auto input = mm->add_parameter("input", s);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
...@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial> ...@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial>
return p; return p;
} }
}; };
template struct test_multinomial<migraphx::shape::float_type>;
template struct test_multinomial<migraphx::shape::half_type>;
// This fails, need to figure out why
// template struct test_multinomial<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,13 +27,14 @@ ...@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_nonzero : verify_program<test_nonzero> template <migraphx::shape::type_t DType>
struct test_nonzero : verify_program<test_nonzero<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{DType, {2, 3, 4, 5}};
auto x = mm->add_parameter("data", s); auto x = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
mm->add_return({r}); mm->add_return({r});
...@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero> ...@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero>
return p; return p;
} }
}; };
template struct test_nonzero<migraphx::shape::float_type>;
template struct test_nonzero<migraphx::shape::half_type>;
template struct test_nonzero<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -23,16 +23,18 @@ ...@@ -23,16 +23,18 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small> template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1}}; migraphx::shape s{DType, {1}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto xb = auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x);
...@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm ...@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
} }
}; };
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large> template struct test_prefix_scan_sum_2d_small<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e4m3fnuz_type>;
template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 1000}}; migraphx::shape s{DType, {3, 1000}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x);
return p; return p;
} }
}; };
template struct test_prefix_scan_sum_2d_large<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>> ...@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
}; };
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
...@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh ...@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh
template struct test_reduce_op_small<migraphx::op::reduce_sum, template struct test_reduce_op_small<migraphx::op::reduce_sum,
2, 2,
migraphx::shape::fp8e4m3fnuz_type>; migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
3,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, template struct test_reduce_op_small<migraphx::op::reduce_mean,
2, 2,
migraphx::shape::fp8e4m3fnuz_type>; migraphx::shape::fp8e4m3fnuz_type>;
......
...@@ -26,16 +26,21 @@ ...@@ -26,16 +26,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_reverse : verify_program<test_reverse> template <migraphx::shape::type_t DType>
struct test_reverse : verify_program<test_reverse<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 16}}; migraphx::shape s{DType, {4, 16}};
auto a0 = mm->add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
std::vector<int64_t> axis = {0}; std::vector<int64_t> axis = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0); mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0);
return p; return p;
} }
}; };
template struct test_reverse<migraphx::shape::float_type>;
template struct test_reverse<migraphx::shape::half_type>;
template struct test_reverse<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment