Commit 3b3b3a7c authored by Umang Yadav's avatar Umang Yadav
Browse files

ck_gemm assertion failure fix

parent dfde6d07
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights * in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is * copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions: * furnished to do so, subject to the following conditions:
* *
* The above copyright notice and this permission notice shall be included in * The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software. * all copies or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP #ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP #define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
...@@ -37,36 +37,52 @@ namespace migraphx { ...@@ -37,36 +37,52 @@ namespace migraphx {
template <class Dims> template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims) constexpr auto ck_transposeb_dims(Dims dims)
{ {
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); }); return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
} }
template <class Tensor> template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens), using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides))); ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
using clock_value_t = long long;
__device__ void sleep(clock_value_t sleep_cycles)
{
clock_value_t start = clock64();
clock_value_t cycles_elapsed;
do { cycles_elapsed = clock64() - start; }
while (cycles_elapsed < sleep_cycles);
}
template <class G, class E, class A, class B, class... Ds> template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{ {
constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(), constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
to_ck_tensor<ck_transposeb<B>>(), to_ck_tensor<ck_transposeb<B>>(),
ck::make_tuple(to_ck_tensor<Ds>()...), ck::make_tuple(to_ck_tensor<Ds>()...),
to_ck_tensor<E>()); to_ck_tensor<E>());
static_assert(desc.is_valid, "Invalid ck gemm.");
G::Run(desc, //static_assert(desc.is_valid, "Invalid ck gemm.");
to_ck_const_pointer(a.data()), if constexpr(not desc.is_valid)
to_ck_const_pointer(b.data()), {
ck::make_tuple(to_ck_const_pointer(ds.data())...), sleep(10000000);
to_ck_pointer(e.data())); return;
}
G::Run(desc,
to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
ck::make_tuple(to_ck_const_pointer(ds.data())...),
to_ck_pointer(e.data()));
} }
template <class G, index_int BlocksPerBatch, class... Ts> template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm(Ts... xs) __device__ void ck_gemm(Ts... xs)
{ {
gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)( gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
[](auto... ys) { ck_gemm_matrix<G>(ys...); }); [](auto... ys) { ck_gemm_matrix<G>(ys...); });
} }
} // namespace migraphx } // namespace migraphx
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment