Commit a98d86d9 authored by Umang Yadav

Merge branch 'rocblas_fp8' into rocblas_mlir_fp8

parents a3d4b013 7e80f627
@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
+#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
@@ -36,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the stored integer enum value to whichever type is required inside the
+`common_args` generator.
+*/
+struct rb_compute_type
+{
+int type = 0;
+rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
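For context, the dual implicit conversions above are what let a single compute_type member be passed both to the regular gemm_ex calls and to the "ex3" calls later in this diff. A minimal, self-contained sketch of the same pattern follows; every name in it (mock_datatype, mock_gemm_ex, and so on) is a hypothetical stand-in, not the real rocBLAS API:

// Minimal sketch of the rb_compute_type pattern using mock enums.
#include <cassert>

enum mock_datatype { mock_datatype_f32_r = 151 };       // stands in for rocblas_datatype
enum mock_computetype { mock_compute_type_f32 = 300 };  // stands in for rocblas_computetype

struct mock_compute_type
{
    int type = 0;
    mock_compute_type(mock_datatype t) : type(static_cast<int>(t)) {}
    mock_compute_type(mock_computetype t) : type(static_cast<int>(t)) {}
    operator mock_datatype() const { return static_cast<mock_datatype>(type); }
    operator mock_computetype() const { return static_cast<mock_computetype>(type); }
};

// Two entry points that disagree on the compute-type parameter, like gemm_ex vs gemm_ex3.
void mock_gemm_ex(mock_datatype ct) { assert(ct == mock_datatype_f32_r); }
void mock_gemm_ex3(mock_computetype ct) { assert(ct == mock_compute_type_f32); }

int main()
{
    mock_compute_type regular = mock_datatype_f32_r; // regular path stores a datatype value
    mock_gemm_ex(regular);                           // implicit conversion back to mock_datatype

    mock_compute_type ex3 = mock_compute_type_f32;   // fp8 path stores a computetype value
    mock_gemm_ex3(ex3);                              // implicit conversion back to mock_computetype
    return 0;
}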
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
@@ -185,12 +200,17 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
-compute_type = output_type;
+compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
+if(arg_type == rocblas_datatype_f8_r)
+{
+assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+compute_type = rocblas_compute_type_f32;
+}
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
@@ -230,7 +250,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
-rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
@@ -240,7 +259,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
-rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
@@ -254,7 +272,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
@@ -264,7 +281,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
@@ -304,7 +320,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
@@ -314,7 +329,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
@@ -365,7 +379,8 @@ struct gemm_impl
output_type,
ldd,
d_stride,
-num_matrices);
+num_matrices,
+compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
@@ -398,7 +413,8 @@ struct gemm_impl
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
-ldd);
+ldd,
+compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
@@ -428,7 +444,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
@@ -438,7 +453,6 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
-compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
@@ -449,7 +463,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
-compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
@@ -459,7 +472,6 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
-compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
-rocblas_datatype compute_type = rocblas_datatype_f32_r;
+rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT
gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream,
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_gather
{
op::gather op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::gather"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_pad
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::pad"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -37,7 +37,7 @@
#include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex>
-#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
#warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
@@ -321,31 +321,30 @@ struct mlir_program
return result;
}
-MlirType make_tensor(const shape& s) const
+MlirType make_mlir_shaped(const shape& s) const
{
-if(not s.standard())
-MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
if(s.dynamic())
MIGRAPHX_THROW("MLIR does not support dynamic shapes");
std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
-return mlirRankedTensorTypeGet(
-lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
+std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
+return rocmlirMIXRShapedTypeGet(
+lens.size(), lens.data(), strides.data(), make_type(s.type()));
}
template <class Range>
-std::vector<MlirType> make_tensors(const Range& r)
+std::vector<MlirType> make_mlir_shapeds(const Range& r)
{
std::vector<MlirType> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
-return make_tensor(s);
+return make_mlir_shaped(s);
});
return result;
}
MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
{
-auto in = make_tensors(inputs);
-auto out = make_tensors(outputs);
+auto in = make_mlir_shapeds(inputs);
+auto out = make_mlir_shapeds(outputs);
return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
}
@@ -507,11 +506,7 @@
mlir_operation_state& add_results(const std::vector<shape>& outputs)
{
-std::vector<shape> reshaped(outputs.size());
-std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
-return shape{r.type(), r.lens()};
-});
-auto x = prog->make_tensors(reshaped);
+auto x = prog->make_mlir_shapeds(outputs);
if(not x.empty())
{
mlirOperationStateAddResults(&op_state, x.size(), x.data());
@@ -584,7 +579,7 @@ struct mlir_program
std::vector<shape> outputs = m.get_output_shapes();
std::vector<MlirLocation> arg_locs(inputs.size(), location);
-auto body_inputs = make_tensors(inputs);
+auto body_inputs = make_mlir_shapeds(inputs);
mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
MlirBlock result = fbody.get();
@@ -610,7 +605,7 @@ struct mlir_program
return "func.return";
if(ins->name() == "@literal")
{
-return "tosa.const";
+return "migraphx.literal";
}
return "migraphx." + ins->name();
}
@@ -669,7 +664,8 @@ struct mlir_program
if(ins->name() == "@literal")
{
literal r = ins->get_literal();
-MlirType tensor_type = make_tensor(ins->get_shape());
+MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
MlirAttribute mlir_value_attr =
mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
ops.add_attributes({{"value", mlir_value_attr}});
@@ -947,35 +943,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
auto param = m.get_parameter(name);
if(input.standard())
continue;
-auto lens = input.lens();
-auto strides = input.strides();
-std::vector<operation> ops;
-if(input.transposed())
-{
-auto perm = find_permutation(input);
-auto iperm = invert_permutation(perm);
-lens = reorder_dims(lens, iperm);
-strides = reorder_dims(strides, iperm);
-ops.push_back(make_op("transpose", {{"permutation", perm}}));
-}
-if(input.broadcasted())
-{
-std::transform(lens.begin(),
-lens.end(),
-strides.begin(),
-lens.begin(),
-[](auto len, auto stride) -> std::size_t {
-if(stride == 0)
-return 1;
-return len;
-});
-ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
-}
-auto new_param =
-std::accumulate(ops.begin(),
-ops.end(),
-m.add_parameter(name + ".0", shape{input.type(), lens}),
-[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+auto new_param = m.add_parameter(name + ".0", input);
m.replace_instruction(param, new_param);
m.remove_instruction(param);
}
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_pad::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs, *this}.has(1).standard();
return op.compute_shape(inputs);
}
argument hip_pad::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::pad(ctx.get_stream().get(), args.back(), args.front(), op.value, op.pads);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -52,7 +52,6 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
-#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
@@ -150,7 +149,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
-eliminate_fp8{unsupported_fp8_ops},
+eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
optimize_module{},
fuse_pointwise{},
...
@@ -141,9 +141,9 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
+%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
-return %0 : tensor<1x2x2x2xf32>
+return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
}
}
)__migraphx__";
@@ -160,15 +160,38 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
+%0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
-%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
+%1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
-%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
+%2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
-return %2 : tensor<1x2x2x2xf32>
+return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
}
}
)__migraphx__";
@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
+%0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
-%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
+%1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
-return %1 : tensor<1x5x3xi32>
+return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
}
}
)__migraphx__";
@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
+%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
-%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
+%1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
-return %1 : tensor<1x5x3xf32>
+return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
}
)__migraphx__";
@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
+%0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
-%1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
+%1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
-%2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
+%2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
-return %2 : tensor<1x2x2x2xi32>
+return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
}
}
)__migraphx__";
@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
+%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
-%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
+%1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
-return %1 : tensor<1x5x3xf16>
+return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
}
}
)__migraphx__";
@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
-func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
+%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
-%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
+%1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
-return %1 : tensor<1x5x3xf32>
+return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
}
)__migraphx__";
...
-a5537f2f563d4975c7e6121a7eb260bbbfd9455a
+d69842226b47e5336568103541b071447caeb9bf
@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>>
};
template struct gemm_2args_mm_8<migraphx::shape::float_type>;
-template struct gemm_2args_mm_8<migraphx::shape::half_type>;
+// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
template struct gemm_2args_mm_8<migraphx::shape::fp8e4m3fnuz_type>;
@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>>
};
template struct gemm_add_broadcast2<migraphx::shape::float_type>;
-template struct gemm_add_broadcast2<migraphx::shape::half_type>;
+// template struct gemm_add_broadcast2<migraphx::shape::half_type>;
template struct gemm_add_broadcast2<migraphx::shape::fp8e4m3fnuz_type>;
@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
print(f"{git_clang_format} not installed. Skipping format.")
return
diff_flag = "" if apply else "--diff"
-run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}")
+run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}"
+)
def get_files_changed(against, ext=('py')):
...