"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "492d329a44f24ac8e1fcb3f2bf355793cef50497"
Commit 781ce146 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add fp16 fixes

parent 6d937d80
...@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype( ...@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
{ {
index_int groups = (n + local - 1) / local; index_int groups = (n + local - 1) / local;
index_int nglobal = std::min<index_int>(256, groups) * local; index_int nglobal = std::min<index_int>(1048576, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) __device__ { launch(stream, nglobal, local)([=](auto idx) __device__ {
......
...@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl( ...@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = 512 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)( hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) { [&](auto output, auto binput, auto... inputs) {
...@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = 512 * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) { hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
...@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl( ...@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = 512 * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)( hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
...@@ -234,7 +234,7 @@ void nary_double_broadcast_impl( ...@@ -234,7 +234,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = 512 * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)( hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
......
...@@ -60,12 +60,17 @@ void gemm_impl( ...@@ -60,12 +60,17 @@ void gemm_impl(
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
auto compute_type = output_type; auto compute_type = output_type;
if(ctx.get_stream().get_device_name() == "gfx908")
{
if(args[0].get_shape().type() == shape::half_type)
compute_type = rocblas_datatype_f32_r;
}
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -91,14 +96,14 @@ void gemm_impl( ...@@ -91,14 +96,14 @@ void gemm_impl(
n, n,
m, m,
k, k,
&alpha_r, &alpha,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta_r, &beta,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -123,7 +128,7 @@ void gemm_impl( ...@@ -123,7 +128,7 @@ void gemm_impl(
n, n,
m, m,
k, k,
&alpha_r, &alpha,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
...@@ -132,7 +137,7 @@ void gemm_impl( ...@@ -132,7 +137,7 @@ void gemm_impl(
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta_r, &beta,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
......
...@@ -87,6 +87,17 @@ struct hip_device ...@@ -87,6 +87,17 @@ struct hip_device
return rbhandle.get(); return rbhandle.get();
} }
std::string get_device_name()
{
hipDeviceProp_t props{};
// int device;
// if (not (hipGetDevice(&device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device");
// if (not (hipGetDeviceProperties(&props, device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device properties");
return "gfx" + std::to_string(props.gcnArch);
}
void wait(hipEvent_t event) void wait(hipEvent_t event)
{ {
setup(); setup();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment