"...composable_kernel_rocm.git" did not exist on "3da5c19e629174c234fe86c17ebd04732ea548b7"
Commit f4965d63 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Cleanup first interface

parent d4881b8a
......@@ -3,19 +3,41 @@
#include "host_interface.hpp"
namespace ck {
namespace app {
// Wrapper around a device buffer pointer and its size.
// NOTE(review): this span reads like a rendered diff with removed and added lines interleaved:
// it contains two size-taking constructors and two GetDeviceBuffer declarations that differ only
// in return type. Confirm which members belong to the current revision before relying on it.
struct DeviceMem
{
// Typed buffer pointer, null by default; no member visible in this struct assigns it.
float* ptr_mem=nullptr;
// Size value stored by the int constructor below.
int size;
// Stores _size only; performs no other work.
DeviceMem(int _size): size(_size){}
// Returns the stored float pointer unchanged.
float* GetDeviceBuffer()
{
return ptr_mem;
}
// Default construction is disabled: a size must always be supplied.
DeviceMem() = delete;
// Size-taking constructor; declared here, defined out of line.
DeviceMem(std::size_t mem_size);
// Returns the untyped buffer pointer; declared here, defined out of line.
void* GetDeviceBuffer();
// Host-to-device copy from p; declared here, defined out of line.
void ToDevice(const void* p);
// Device-to-host copy into p; declared here, defined out of line.
void FromDevice(void* p);
// Destructor; declared here, defined out of line.
~DeviceMem();
// Untyped buffer pointer used by the out-of-line member functions.
void* mpDeviceBuf;
// Byte count given at construction; used as the length of every copy.
std::size_t mMemSize;
};
namespace ck {
namespace app {
/// Requests mem_size bytes from hipMalloc and remembers the requested size.
///
/// mpDeviceBuf is set to nullptr before the allocation is attempted so that, should hipMalloc
/// fail without writing the output pointer, GetDeviceBuffer() and the hipFree() call in the
/// destructor see a null pointer rather than an indeterminate value.
///
/// @param mem_size size passed to hipMalloc and stored in mMemSize.
DeviceMem::DeviceMem(std::size_t mem_size) : mpDeviceBuf(nullptr), mMemSize(mem_size)
{
    // NOTE(review): the hipError_t is only turned into a string and that string is discarded,
    // so an allocation failure is not reported to the caller.
    hipGetErrorString(hipMalloc(&mpDeviceBuf, mMemSize));
}
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
/// Copies mMemSize bytes from p into mpDeviceBuf using hipMemcpyHostToDevice.
///
/// The source pointer is handed to hipMemcpy as-is: its source parameter is already
/// `const void*`, so casting constness away is unnecessary.
///
/// @param p source pointer; mMemSize bytes are read from it.
void DeviceMem::ToDevice(const void* p)
{
    // NOTE(review): the returned status is stringified and then dropped; failures are silent.
    hipGetErrorString(hipMemcpy(mpDeviceBuf, p, mMemSize, hipMemcpyHostToDevice));
}
/// Copies mMemSize bytes from mpDeviceBuf into p using hipMemcpyDeviceToHost.
/// @param p destination pointer; mMemSize bytes are written to it.
void DeviceMem::FromDevice(void* p)
{
    const auto copy_status = hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost);
    // The status is converted to text and the text is not used further.
    hipGetErrorString(copy_status);
}
/// Hands mpDeviceBuf to hipFree.
DeviceMem::~DeviceMem()
{
    const auto free_status = hipFree(mpDeviceBuf);
    // The status is converted to text and the text is not used further.
    hipGetErrorString(free_status);
}
void profile_conv_fwd_impl(int do_verification,
int init_method,
......@@ -49,9 +71,9 @@ void profile_conv_fwd_impl(int do_verification,
using InDataType = float;
using OutDataType = float;
DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz);
DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz);
app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
// data is already on device!
......
# device_conv2d_fwd_instance
set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance2.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
......
......@@ -3,6 +3,7 @@
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace ck {
namespace tensor_operation {
......@@ -106,3 +107,85 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
// File-scope shorthand for CK's PassThrough element-wise operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
/// Implementation type behind DeviceConvFwdPtr_t: holds one CK convolution-forward device
/// instance pointer (`el`) and forwards every call to it.
///
/// Deliberately left as an aggregate (no user-declared constructors, `el` is the only data
/// member) because this file brace-initialises it directly from a DeviceConvFwdPtr.
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{
    /// Forwards all arguments to `el->MakeArgumentPointer`, appending three default-constructed
    /// PassThrough operations.
    ///
    /// The vector parameters are received by value, so they are moved into the forwarded call
    /// instead of being copied a second time. If the callee takes them by const reference the
    /// move is simply a no-op.
    ///
    /// @return whatever `el->MakeArgumentPointer` returns.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
    MakeArgumentPointer(void* in_ptr,
                        void* wei_ptr,
                        void* out_ptr,
                        size_t N,
                        size_t K,
                        size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads)
    {
        return el->MakeArgumentPointer(in_ptr,
                                       wei_ptr,
                                       out_ptr,
                                       N,
                                       K,
                                       C,
                                       std::move(input_spatial_lengths),
                                       std::move(filter_spatial_lengths),
                                       std::move(output_spatial_lengths),
                                       std::move(conv_filter_strides),
                                       std::move(conv_filter_dilations),
                                       std::move(input_left_pads),
                                       std::move(input_right_pads),
                                       PassThrough{},
                                       PassThrough{},
                                       PassThrough{});
    }

    /// @return the result of `el->MakeInvokerPointer()`.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer()
    {
        return el->MakeInvokerPointer();
    }

    /// @return the result of `el->GetTypeString()`.
    std::string GetTypeString()
    {
        return el->GetTypeString();
    }

    /// @param arg argument object handed unchanged to the wrapped instance.
    /// @return the result of `el->IsSupportedArgument(arg)`.
    bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
    {
        return el->IsSupportedArgument(arg);
    }

    // Wrapped CK instance; every method above dereferences it without a null check.
    ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
};
/// Creates a handle with no implementation attached (pImpl is null).
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t()
    : pImpl(nullptr)
{
}

/// Destructor, defaulted in this file.
// NOTE(review): presumably defined here rather than in the header because the Impl type is
// only complete in this translation unit — confirm against the header.
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;

/// Move constructor, defaulted in this file.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;

/// Builds a handle around a freshly allocated Impl that is move-constructed from `other`.
/// Although `other` is taken by lvalue reference, it is moved from by this constructor.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
    : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument> DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr,
size_t N, size_t K, size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
return pImpl->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides,
conv_filter_dilations, input_left_pads, input_right_pads);
}
/// Asks the implementation object for an invoker and hands it back to the caller.
/// pImpl is dereferenced without a null check.
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer()
{
    auto invoker = pImpl->MakeInvokerPointer();
    return invoker;
}
/// Returns the type string reported by the implementation object.
/// pImpl is dereferenced without a null check.
std::string DeviceConvFwdPtr_t::GetTypeString()
{
    std::string type_name = pImpl->GetTypeString();
    return type_name;
}
/// Passes arg_ptr to the implementation object's IsSupportedArgument and returns its answer.
/// pImpl is dereferenced without a null check.
/// @param arg_ptr argument object forwarded unchanged.
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
    const bool supported = pImpl->IsSupportedArgument(arg_ptr);
    return supported;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
// convert local_instances to instances
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment