Commit 0b11569f authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
add_subdirectory(src/host_tensor)
add_subdirectory(src/tensor_operation_instance/gpu) add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/host_tensor)
add_subdirectory(src/utility) add_subdirectory(src/utility)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <vector> #include <vector>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "host_tensor.hpp" #include "host_tensor.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <vector> #include <vector>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <thread> #include <thread>
...@@ -219,6 +222,12 @@ struct Tensor ...@@ -219,6 +222,12 @@ struct Tensor
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {} Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
Tensor& operator=(const Tensor& other)
{
mDesc = other.mDesc;
mData = other.mData;
}
template <typename F> template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank) void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
{ {
...@@ -361,13 +370,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens, ...@@ -361,13 +370,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
{ {
} }
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
#if 1 #if 1
// FIXME: remove // FIXME: remove
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
#endif
template <typename T> template <typename T>
float check_error(const Tensor<T>& ref, const Tensor<T>& result) float check_error(const Tensor<T>& ref, const Tensor<T>& result)
{ {
...@@ -413,3 +417,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -413,3 +417,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
return linf_error; return linf_error;
} }
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <cmath> #include <cmath>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -59,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -59,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
float v_a; ADataType v_a;
float v_b; BDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_g_m_k_(g, m, k))); arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_g_k_n_(g, k, n))); arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc += v_a * v_b; v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
} }
float v_c; float v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_g_m_n_(g, m, n) = v_c; arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
}; };
make_ParallelTensorFunctor(f_gmk_gkn_gmn, make_ParallelTensorFunctor(f_gmk_gkn_gmn,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -60,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -60,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
AccDataType v_a; ADataType v_a;
AccDataType v_b; BDataType v_b;
arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, arg.a_m_k_(m, k));
arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, arg.b_k_n_(k, n));
v_acc += v_a * v_b; v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
AccDataType v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = v_c; arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment