Commit 7ea9c9c4 authored by Chao Liu
Browse files

Merge branch 'fix_0813' into fused-gemm

parents 2564c493 8bea6b2d
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -186,9 +186,9 @@ int main(int argc, char* argv[]) ...@@ -186,9 +186,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_B_REGISTER_HPP // SPDX-License-Identifier: MIT
#define CK_BLOCKWISE_GEMM_XDLOPS_B_REGISTER_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...@@ -317,4 +319,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 ...@@ -317,4 +319,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
}; };
} // namespace ck } // namespace ck
#endif
#ifndef DEVICE_GEMM_XDL_SKIP_B_LDS_HPP // SPDX-License-Identifier: MIT
#define DEVICE_GEMM_XDL_SKIP_B_LDS_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -11,8 +13,9 @@ ...@@ -11,8 +13,9 @@
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -518,4 +521,3 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -518,4 +521,3 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP // SPDX-License-Identifier: MIT
#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
...@@ -674,4 +676,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -674,4 +676,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
}; };
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATIC_BUFFER_HPP #pragma once
#define CK_STATIC_BUFFER_HPP
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
...@@ -20,13 +19,6 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -20,13 +19,6 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ constexpr StaticBuffer() : base{} {} __host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer& operator=(StaticBuffer& y)
{
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
return x;
}
template <typename... Ys> template <typename... Ys>
__host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y) __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
{ {
...@@ -201,4 +193,3 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>) ...@@ -201,4 +193,3 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
} }
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment