Unverified Commit 3bc4a779 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into doc-standard

parents 3053fc95 6a72e8fc
@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
 Optimize when reading
+
+.. option:: --apply-pass, -p
+
+   Passes to apply to model
 .. option:: --graphviz, -g
 Print out a graphviz representation.
...
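Since the option is registered with `ap.append()` (see the driver change below), `--apply-pass` can be given multiple times to queue passes by name; a hypothetical invocation would be `migraphx-driver read model.onnx --apply-pass simplify_reshapes --apply-pass dead_code_elimination`, where the accepted names are the ones registered in `create_passes_lookup()` further down.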
@@ -25,6 +25,7 @@
 add_executable(driver
 main.cpp
 verify.cpp
+passes.cpp
 perf.cpp
 resnet50.cpp
 inceptionv3.cpp
...
@@ -26,6 +26,7 @@
 #include "argument_parser.hpp"
 #include "command.hpp"
 #include "precision.hpp"
+#include "passes.hpp"
 #include "perf.hpp"
 #include "models.hpp"
 #include "marker_roctx.hpp"
@@ -83,6 +84,7 @@ struct loader
 std::vector<std::string> param_dims;
 std::vector<std::string> dyn_param_dims;
 std::vector<std::string> output_names;
+std::vector<std::string> passes;
 void parse(argument_parser& ap)
 {
@@ -130,6 +132,7 @@ struct loader
 ap.append(),
 ap.nargs(2));
 ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
+ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
 ap(output_type,
 {"--graphviz", "-g"},
 ap.help("Print out a graphviz representation."),
@@ -337,6 +340,8 @@ struct loader
 migraphx::dead_code_elimination{},
 });
 }
+if(not passes.empty())
+    migraphx::run_passes(*p.get_main_module(), get_passes(passes));
 return p;
 }
...
@@ -22,23 +22,88 @@
 * THE SOFTWARE.
 */
-#include "verify_program.hpp"
-#include <migraphx/program.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/make_op.hpp>
-
-struct test_conv_relu_half : verify_program<test_conv_relu_half>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        auto* mm = p.get_main_module();
-        auto input =
-            mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
-        auto weights =
-            mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
-        auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
-        mm->add_instruction(migraphx::make_op("relu"), conv);
-        return p;
-    }
-};
+#include "passes.hpp"
+
+#include <migraphx/auto_contiguous.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_allocation.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/eliminate_concat.hpp>
+#include <migraphx/eliminate_contiguous.hpp>
+#include <migraphx/eliminate_data_type.hpp>
+#include <migraphx/eliminate_identity.hpp>
+#include <migraphx/eliminate_pad.hpp>
+#include <migraphx/inline_module.hpp>
+#include <migraphx/insert_pad.hpp>
+#include <migraphx/normalize_ops.hpp>
+#include <migraphx/optimize_module.hpp>
+#include <migraphx/promote_literals.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/rewrite_gelu.hpp>
+#include <migraphx/rewrite_pooling.hpp>
+#include <migraphx/rewrite_quantization.hpp>
+#include <migraphx/rewrite_rnn.hpp>
+#include <migraphx/simplify_algebra.hpp>
+#include <migraphx/simplify_dyn_ops.hpp>
+#include <migraphx/simplify_qdq.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/ranges.hpp>
+#include <unordered_map>
+
+namespace migraphx {
+namespace driver {
+inline namespace MIGRAPHX_INLINE_NS {
+
+std::unordered_map<std::string, pass> create_passes_lookup()
+{
+    std::unordered_map<std::string, pass> result;
+    // clang-format off
+    std::initializer_list<pass> passes = {
+        auto_contiguous{},
+        dead_code_elimination{},
+        eliminate_allocation{},
+        eliminate_common_subexpression{},
+        eliminate_concat{},
+        eliminate_contiguous{},
+        eliminate_data_type{},
+        eliminate_identity{},
+        eliminate_pad{},
+        inline_module{},
+        insert_pad{},
+        normalize_ops{},
+        optimize_module{},
+        promote_literals{},
+        propagate_constant{},
+        rewrite_gelu{},
+        rewrite_pooling{},
+        rewrite_quantization{},
+        rewrite_rnn{},
+        simplify_algebra{},
+        simplify_dyn_ops{},
+        simplify_qdq{},
+        simplify_reshapes{},
+    };
+    // clang-format on
+    for(const auto& pass : passes)
+        result[pass.name()] = pass;
+    result["eliminate_dead_code"] = dead_code_elimination{};
+    return result;
+}
+
+std::vector<pass> get_passes(const std::vector<std::string>& names)
+{
+    std::vector<pass> result;
+    static const std::unordered_map<std::string, pass> lookup = create_passes_lookup();
+    std::transform(
+        names.begin(), names.end(), std::back_inserter(result), [](const std::string& name) {
+            if(not contains(lookup, name))
+                MIGRAPHX_THROW("Unknown pass: " + name);
+            return lookup.at(name);
+        });
+    return result;
+}
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace driver
+} // namespace migraphx
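Taken together with the loader change above, the flow is small enough to sketch. This is a minimal illustration: the helper name `apply_named_passes` is mine, and it assumes `migraphx::run_passes` from `<migraphx/pass_manager.hpp>`, the same call the loader hunk makes.

    #include <migraphx/program.hpp>
    #include <migraphx/pass_manager.hpp> // declares migraphx::run_passes
    #include <string>
    #include <vector>
    #include "passes.hpp" // declares migraphx::driver::get_passes

    // Resolve each --apply-pass name, then run the passes on the main module;
    // get_passes throws "Unknown pass: <name>" for anything not in the lookup.
    void apply_named_passes(migraphx::program& p, const std::vector<std::string>& names)
    {
        migraphx::run_passes(*p.get_main_module(), migraphx::driver::get_passes(names));
    }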
@@ -21,23 +21,20 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#include "verify_program.hpp"
-#include <migraphx/program.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/make_op.hpp>
-
-struct test_nonzero_half : verify_program<test_nonzero_half>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        auto* mm = p.get_main_module();
-        migraphx::shape s{migraphx::shape::half_type, {3, 4, 3, 5}};
-        auto x = mm->add_parameter("data", s);
-        auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
-        mm->add_return({r});
-        return p;
-    }
-};
+#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
+#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
+
+#include <migraphx/pass.hpp>
+#include <vector>
+
+namespace migraphx {
+namespace driver {
+inline namespace MIGRAPHX_INLINE_NS {
+
+std::vector<pass> get_passes(const std::vector<std::string>& names);
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace driver
+} // namespace migraphx
+#endif
@@ -31,6 +31,72 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+void insert_convert_to_supported_type(module& m,
+                                      instruction_ref ins,
+                                      migraphx::shape::type_t target_type,
+                                      std::set<migraphx::shape::type_t> unsupported_types)
+{
+    migraphx::shape::type_t orig_type   = ins->get_shape().type();
+    std::vector<instruction_ref> inputs = ins->inputs();
+    std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
+        if(contains(unsupported_types, i->get_shape().type()))
+        {
+            return m.insert_instruction(
+                ins,
+                migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
+                i);
+        }
+        else
+        {
+            return i;
+        }
+    });
+    // if no change
+    if(inputs == ins->inputs())
+        return;
+    auto op         = ins->get_operator();
+    auto attributes = op.attributes();
+    if(attributes.contains("general_data_type"))
+    {
+        op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
+    }
+    auto new_ins = m.insert_instruction(ins, op, inputs);
+    if(orig_type == shape::tuple_type)
+    {
+        auto orig_outs = ins->outputs();
+        if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
+               return out_ins->name() == "get_tuple_elem";
+           }))
+            MIGRAPHX_THROW(
+                "eliminate_data_type: Instruction with tuple output doesn't have all its "
+                "usages as get_tuple_elem instruction");
+        std::transform(
+            orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
+                auto gte_ins       = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
+                auto orig_out_type = out_ins->get_shape().type();
+                if(contains(unsupported_types, orig_out_type))
+                {
+                    auto gte_convert = m.insert_instruction(
+                        ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
+                    return m.replace_instruction(out_ins, gte_convert);
+                }
+                else
+                {
+                    return m.replace_instruction(out_ins, gte_ins);
+                }
+            });
+    }
+    else
+    {
+        auto convert_back_ins = m.insert_instruction(
+            ins,
+            migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
+            new_ins);
+        m.replace_instruction(ins, convert_back_ins);
+    }
+}
+
 void eliminate_data_type::apply(module& m) const
 {
 static const std::vector<std::string> skip_op_names = {"convert",
@@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
 "scatternd_add",
 "scatternd_mul",
 "scatternd_none"};
+if(unsupported_types.empty())
+    return;
 for(auto ins : iterator_for(m))
 {
 if(ins->name()[0] == '@')
 continue;
-if(contains(skip_op_names, ins->name()))
+if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
 continue;
-auto inputs = ins->inputs();
-std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
-    if(types.count(i->get_shape().type()) == 0)
-        return i;
-    return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
-});
-if(inputs == ins->inputs())
-continue;
-auto op = ins->get_operator();
-auto attributes = op.attributes();
-if(attributes.contains("general_data_type"))
-{
-    op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
-}
-auto old_type = ins->get_shape().type();
-auto out = m.insert_instruction(ins, op, inputs);
-auto convert =
-    m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
-m.replace_instruction(ins, convert);
+if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
+    insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
 }
 }
...
@@ -40,8 +40,9 @@ struct module;
 */
 struct MIGRAPHX_EXPORT eliminate_data_type
 {
-    std::set<shape::type_t> types;
+    std::set<shape::type_t> unsupported_types;
     shape::type_t target_type;
+    std::set<std::string> unsupported_ops = {"all"};
     std::string name() const { return "eliminate_data_type"; }
     void apply(module& m) const;
 };
...
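A minimal sketch of configuring the reworked pass; the field order follows the struct above, the construction mirrors the one added in target.cpp at the end of this commit, and the helper name `demote_fp8_pooling` is mine.

    #include <migraphx/eliminate_data_type.hpp>
    #include <migraphx/module.hpp>

    void demote_fp8_pooling(migraphx::module& m)
    {
        // Convert fp8e4m3fnuz only around ops named in unsupported_ops; the
        // default {"all"} keeps the old convert-everything behavior.
        migraphx::eliminate_data_type pass{
            {migraphx::shape::fp8e4m3fnuz_type}, // unsupported_types
            migraphx::shape::float_type,         // target_type
            {"pooling"}};                        // unsupported_ops
        pass.apply(m); // pooling becomes convert -> pooling(float) -> convert(fp8)
    }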
@@ -27,6 +27,7 @@
 #include <migraphx/op/common.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/shape.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/convolution.hpp>
 #include <migraphx/value.hpp>
@@ -87,11 +88,13 @@ struct quant_convolution
 }
 // all input type must be int8_type and output is float_type
-if(t != shape::int8_type)
+std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
+                                                     shape::fp8e4m3fnuz_type};
+if(not contains(supported_types, t))
 {
-    MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
+    MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
+                   "fp8e4m3fnuz_type");
 }
-t = shape::int32_type;
 std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
 auto padding_size = padding.size();
@@ -107,8 +110,11 @@ struct quant_convolution
 stride[i] +
 1)));
 }
-return inputs[0].with_lens(t, output_lens);
+if(t == shape::int8_type)
+{
+    return inputs[0].with_lens(shape::int32_type, output_lens);
+} // else fp8 conv
+return inputs[0].with_lens(shape::float_type, output_lens);
 }
 size_t kdims() const
...
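To make the new output-type rule concrete, a hedged sketch; it assumes the free function `migraphx::compute_shape(operation, inputs)` that MIGraphX uses for shape inference, and the example shapes are illustrative.

    #include <migraphx/make_op.hpp>
    #include <migraphx/operation.hpp>
    #include <migraphx/shape.hpp>
    #include <cassert>

    void check_fp8_quant_conv_output_type()
    {
        // int8 inputs accumulate in int32; fp8e4m3fnuz inputs now produce float
        migraphx::shape x{migraphx::shape::fp8e4m3fnuz_type, {1, 3, 32, 32}};
        migraphx::shape w{migraphx::shape::fp8e4m3fnuz_type, {8, 3, 3, 3}};
        auto out = migraphx::compute_shape(migraphx::make_op("quant_convolution"), {x, w});
        assert(out.type() == migraphx::shape::float_type); // lens: {1, 8, 30, 30}
    }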
@@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
 check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
 # Beta API for automated GEMM tuning
 check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
+# rocblas FP8 API
+check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
 set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
@@ -288,10 +290,18 @@ else()
 message(STATUS "rocBLAS does not have User Tuning Beta API")
 endif()
+if(HAS_ROCBLAS_FP8_BETA_API)
+    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
+    message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
+else()
+    message(STATUS "rocBLAS does not have Fp8 Beta API")
+endif()
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
 target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
 if(MIGRAPHX_USE_COMPOSABLEKERNEL)
 target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
 endif()
 add_subdirectory(driver)
...
@@ -49,6 +49,12 @@ std::string get_device_name()
 return props.gcnArchName;
 }
+bool gfx_has_fp8_intrinsics()
+{
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+}
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
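The version test relies on same-length gcnArchName strings ordering lexicographically: "gfx940", "gfx941", and "gfx942" satisfy `device_name >= "gfx940"`, while "gfx906", "gfx908", and "gfx90a" do not ('0' < '4' at the fifth character), and non-gfx9 names such as "gfx1100" are already rejected by the `starts_with` guard.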
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
 return false;
 if(ins->name() != "convolution" and ins->name() != "quant_convolution")
 return false;
+auto input_arg_t = ins->inputs().front()->get_shape().type();
 value v = ins->get_operator().to_value();
 auto group = v.at("group").to<int>();
 if(group != 1)
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
 // Avoid MLIR assertion: Index < Length && "Invalid index!"
 if(ins->get_shape().lens().size() != 4)
 return false;
+if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
+    return true;
+if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
+    return true;
 if(ins->get_shape().type() == shape::int8_type)
 return true;
 if(mode == mlir_mode::int8)
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
 const auto result_type = i.get_shape().type();
 const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                      type_t::half_type,
+                                                     type_t::fp8e4m3fnuz_type,
                                                      type_t::int8_type,
                                                      type_t::int32_type,
                                                      type_t::bool_type};
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
 "softmax",
 "tanh",
 };
-bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+bool is_float =
+    contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
 if(contains(any_type_ops, name))
 return true;
 if(result_type != type_t::bool_type and contains(no_bool_ops, name))
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
 // supported.
 if(is_float and name == "convert")
 {
+if(result_type == shape::fp8e4m3fnuz_type)
+{
+    return false;
+} // else
 return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
     return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
 });
@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
 void apply(module_pass_manager& mpm, const match::matcher_result& r) const
 {
 auto gemm_based_op = r.result;
-//
-// enable only for fp32/fp16/i8 types
+// enable only for fp32/fp16/i8/fp8 types
 if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
-    return not contains(
-        {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
-        i->get_shape().type());
+    return not contains({shape::type_t::float_type,
+                         shape::type_t::half_type,
+                         shape::type_t::int8_type,
+                         shape::type_t::fp8e4m3fnuz_type},
+                        i->get_shape().type());
 }))
 return;
 static size_t counter = 0;
@@ -531,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
 match::find_matches(
 mpm,
-find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
+find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
 find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
 #else
 (void)mpm;
...
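One behavioral note on the final hunk: standalone convolutions previously reached MLIR only under `mlir_mode::int8`; the switch to `mlir_mode::fast` broadens that default, which appears to be what lets the fp8e4m3fnuz convolution cases added above actually reach the MLIR backend.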
@@ -22,11 +22,14 @@
 * THE SOFTWARE.
 */
+#include <rocblas/internal/rocblas-types.h>
 #include <rocblas/rocblas.h>
+#include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/time.hpp>
+#include <type_traits>
 using microseconds = std::chrono::duration<double, std::micro>;
@@ -34,6 +37,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the stored integer enum value to whichever type the `common_args` generator needs.
+*/
+struct rb_compute_type
+{
+    int type = 0;
+    rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+    rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
 // Convert rocBLAS datatypes to equivalent Migraphx data types
 rocblas_datatype get_type(shape::type_t type)
 {
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
 case shape::uint8_type: return rocblas_datatype_u8_r;
 case shape::int32_type: return rocblas_datatype_i32_r;
 case shape::uint32_type: return rocblas_datatype_u32_r;
-case shape::fp8e4m3fnuz_type:
+case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
 case shape::tuple_type:
 case shape::bool_type:
 case shape::uint16_type:
@@ -183,12 +200,17 @@ struct gemm_impl
 {
 output_type = rocblas_datatype_i32_r;
 }
-compute_type = output_type;
+compute_type = rb_compute_type{output_type};
 if(compute_fp32)
 {
 if(arg_type == rocblas_datatype_f16_r)
 compute_type = rocblas_datatype_f32_r;
 }
+if(arg_type == rocblas_datatype_f8_r)
+{
+    assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+    compute_type = rocblas_compute_type_f32;
+}
 auto a_lens = input_shapes[0].lens();
 auto b_lens = input_shapes[1].lens();
@@ -217,23 +239,52 @@ struct gemm_impl
 void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
 {
-if(strided_batched)
-{
-    auto common_args = create_strided_batched_args_common(ctx, input_args);
-    rocblas_invoke(&rocblas_gemm_strided_batched_ex,
-                   common_args,
-                   rocblas_gemm_algo_solution_index,
-                   solution_idx,
-                   gemm_flags);
-}
-else
-{
-    auto common_args = create_gemm_ex_args_common(ctx, input_args);
-    rocblas_invoke(&rocblas_gemm_ex,
-                   common_args,
-                   rocblas_gemm_algo_solution_index,
-                   solution_idx,
-                   gemm_flags);
-}
+#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
+    if(rocblas_fp8_available() and
+       std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
+           return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+       }))
+    {
+        if(strided_batched)
+        {
+            auto common_args = create_strided_batched_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
+                           common_args,
+                           rocblas_gemm_algo_standard,
+                           solution_idx,
+                           gemm_flags);
+        }
+        else
+        {
+            auto common_args = create_gemm_ex_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_ex3,
+                           common_args,
+                           rocblas_gemm_algo_standard,
+                           solution_idx,
+                           gemm_flags);
+        }
+    }
+    else
+#endif
+    {
+        if(strided_batched)
+        {
+            auto common_args = create_strided_batched_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           solution_idx,
+                           gemm_flags);
+        }
+        else
+        {
+            auto common_args = create_gemm_ex_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_ex,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           solution_idx,
+                           gemm_flags);
+        }
+    }
 }
@@ -331,7 +382,6 @@ struct gemm_impl
 num_matrices,
 compute_type);
 }
-
 /**
 * Helper method to create that subset of a long rocBLAS argument list that is common
 * to multiple "gemm_ex..." calls.
@@ -366,6 +416,7 @@ struct gemm_impl
 ldd,
 compute_type);
 }
+
 #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
 /**
 * Find best rocBLAS solution: Get list of solutions and try them all, returning the index
@@ -481,8 +532,8 @@ struct gemm_impl
 rocblas_int b_stride = 0;
 rocblas_int c_stride = 0;
 rocblas_int d_stride = 0;
-rocblas_datatype compute_type = rocblas_datatype_f32_r;
 rocblas_datatype arg_type = rocblas_datatype_f32_r;
+rb_compute_type compute_type = rocblas_datatype_f32_r;
 rocblas_datatype output_type = rocblas_datatype_f32_r;
 bool strided_batched = true;
 bool is_3inputs = true;
...
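The point of `rb_compute_type` is easiest to see in isolation; this is a hedged sketch using only enum values that appear in this hunk.

    // One member serves both entry points: the legacy gemm_ex path reads it as a
    // rocblas_datatype, while the ex3 fp8 path reads it as a rocblas_computetype.
    rb_compute_type ct = rocblas_datatype_f32_r; // constructed from the regular enum
    rocblas_datatype legacy = ct;                // implicit cast for rocblas_gemm_ex
    ct = rocblas_compute_type_f32;               // reassigned from the ex3 enum
    rocblas_computetype ex3 = ct;                // implicit cast for rocblas_gemm_ex3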
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
 MIGRAPHX_GPU_EXPORT int get_device_id();
+MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -40,6 +40,8 @@ struct context;
 MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();
+MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
 return y;
 }
-template <unsigned int DppCtrl,
-          unsigned int RowMask = 0xf,
-          unsigned int BankMask = 0xf,
-          bool BoundCtrl = false,
-          class T>
-__device__ T dpp_mov(T& x)
+template <class T, class F>
+__device__ T dpp_op(T& x, F f)
 {
 static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
 union type
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
 input.data = x;
 for(index_int i = 0; i < n; i++)
 {
-output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
+output.reg[i] = f(input.reg[i]);
 }
 return output.data;
 }
+
+template <unsigned int DppCtrl,
+          unsigned int RowMask = 0xf,
+          unsigned int BankMask = 0xf,
+          bool BoundCtrl = false,
+          class T>
+__device__ T dpp_mov(T& x)
+{
+    return dpp_op(x,
+                  [](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
+}
+
+template <unsigned int Mask, class T>
+__device__ T dpp_swizzle(T& x)
+{
+    return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
+}
 #endif // MIGRAPHX_HAS_DPP
 } // namespace migraphx
...
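The magic constant 0x1e0 is easiest to check against a host-side model of the `ds_swizzle` bit-masked encoding. The encoding used here (and_mask = offset[4:0], or_mask = offset[9:5], xor_mask = offset[14:10], applied within 32-lane groups) is my reading of the AMD ISA documentation, not something this header states:

    // With offset 0x1e0: and_mask = 0, or_mask = 0xF, xor_mask = 0, so every lane
    // reads lane 15 of its 32-lane group -- a broadcast of the first row's partial
    // reduction, standing in for dpp_row_bcast on wave32 (see reduce.hpp below).
    constexpr unsigned swizzle_src_lane(unsigned lane, unsigned offset)
    {
        unsigned and_mask = offset & 0x1fu;
        unsigned or_mask  = (offset >> 5) & 0x1fu;
        unsigned xor_mask = (offset >> 10) & 0x1fu;
        unsigned group    = lane & ~0x1fu; // swizzle never crosses 32-lane groups
        return group | ((((lane & 0x1fu) & and_mask) | or_mask) ^ xor_mask);
    }
    static_assert(swizzle_src_lane(7, 0x1e0) == 15, "0x1e0 broadcasts lane 15");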
@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
 {
 return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
 }
-// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
-// want to make this distinction. For the floating points we would end up using lowest most of
-// the times.
+// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
 static constexpr __device__ fp8e5m2fnuz min()
 {
 return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
 }
 static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
-// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
-// want to make this distinction. For the floating points we would end up using lowest most of
-// the times.
+// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
 static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
 static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
...
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
 in = op(in, out);
 out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
 in = op(in, out);
-#if __AMDGCN_WAVEFRONT_SIZE == 64
+#if __AMDGCN_WAVEFRONT_SIZE == 32
+out = dpp_swizzle<0x1e0>(in);
+in = op(in, out);
+#else
 out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
 in = op(in, out);
 out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
 }
 #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
+    (void)f;                               \
+    x = 1
 #elif __AMDGCN_WAVEFRONT_SIZE == 64
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
 __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                  "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                  "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
                  "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
                  "s_nop 1\n" \
                  : "=v"(x) \
-                 : "0"(x))
+                 : "0"(x)); \
+(void)f
 #else
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
 __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                  "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                  "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
                  "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
-                 "s_nop 1\n" \
-                 "s_nop 1\n" \
                  : "=v"(x) \
-                 : "0"(x))
+                 : "0"(x)); \
+auto y = dpp_swizzle<0x1e0>(x); \
+x = f(x, y)
 #endif
 // NOLINTNEXTLINE
 #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
-__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
-__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
-__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
-__device__ inline void dpp_reduce(int32_t& x, op) \
-{ \
-    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
-} \
-__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
+__device__ inline void dpp_reduce(double& x, op f) \
+{ \
+    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
+} \
+__device__ inline void dpp_reduce(float& x, op f) \
+{ \
+    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
+} \
+__device__ inline void dpp_reduce(half& x, op f) \
+{ \
+    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
+} \
+__device__ inline void dpp_reduce(int32_t& x, op f) \
+{ \
+    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
+} \
+__device__ inline void dpp_reduce(uint32_t& x, op f) \
+{ \
+    MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
+}
 // Note: when max and min are in int32_t, signed version of instruction needs to be used.
 MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
 MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
-#if __AMDGCN_WAVEFRONT_SIZE == 32
-constexpr index_int lanes_per_thread = 16;
-#else
-constexpr index_int lanes_per_thread = 64;
-#endif
+constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
 using type = decltype(index::invoke_loop(f, 0, _c<0>));
 __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
 type x = type(init);
...
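With the swizzle step in place, a wave32 wavefront completes its reduction entirely in registers, so `block_reduce` can use the full wavefront size as `lanes_per_thread` on both wave32 and wave64; on wave32 this also halves the `__shared__` staging buffer relative to the old 16-lane scheme.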
@@ -300,6 +300,8 @@ struct mlir_program
 result = mlirF32TypeGet(ctx.get());
 else if(as.type_enum() == shape::half_type)
 result = mlirF16TypeGet(ctx.get());
+else if(as.type_enum() == shape::fp8e4m3fnuz_type)
+    result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
 else if(as.type_enum() == shape::double_type)
 result = mlirF64TypeGet(ctx.get());
 else if(as.is_integral())
...
@@ -53,6 +53,15 @@ bool get_compute_fp32_flag()
 return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }
+bool rocblas_fp8_available()
+{
+#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
+    return false;
+#else
+    return gfx_has_fp8_intrinsics();
+#endif
+}
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -105,6 +105,29 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
 unsupported_types.erase(shape::type_t::uint8_type);
 unsupported_types.erase(shape::type_t::int32_type);
 unsupported_types.erase(shape::type_t::tuple_type);
+// whitelist supported ops for FP8
+std::set<std::string> unsupported_fp8_ops = {};
+if(not gpu::rocblas_fp8_available())
+{
+    unsupported_fp8_ops.insert("dot");
+}
+// MIOpen doesn't have support for fp8 pooling yet.
+unsupported_fp8_ops.insert("pooling");
+if(not gpu::gfx_has_fp8_intrinsics())
+{
+    unsupported_fp8_ops.insert("convolution");
+    unsupported_fp8_ops.insert("quant_convolution");
+}
+// add all device kernels
+unsupported_fp8_ops.insert("logsoftmax");
+unsupported_fp8_ops.insert("nonzero");
+unsupported_fp8_ops.insert("prefix_scan_sum");
+unsupported_fp8_ops.insert("scatter_none");
+unsupported_fp8_ops.insert("topk");
+unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
+unsupported_fp8_ops.insert("multinomial");
+unsupported_fp8_ops.insert("argmax");
+unsupported_fp8_ops.insert("argmin");
 // clang-format off
 return
 {
@@ -136,6 +159,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
 prefuse_ops{},
 dead_code_elimination{},
 auto_contiguous{},
+eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
+dead_code_elimination{},
 optimize_module{},
 fuse_pointwise{},
 dead_code_elimination{},
...
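Net effect of this block: an fp8e4m3fnuz instruction whose op name lands in `unsupported_fp8_ops` is rewritten by `eliminate_data_type` into convert(fp8 -> float) -> op -> convert(float -> fp8), while dot, convolution, and quant_convolution stay in fp8 whenever `rocblas_fp8_available()` and `gfx_has_fp8_intrinsics()` report library and hardware support; the trailing `dead_code_elimination` then sweeps up any converts made redundant by the rewrite.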