Commit 05605886 authored by Paul's avatar Paul
Browse files

Use buffer addressing

parent 7662d9c0
#ifndef MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
#define MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
#include <migraphx/kernels/types.hpp>
#ifndef MIGRAPHX_HAS_BUFFER_ADDR
#define MIGRAPHX_HAS_BUFFER_ADDR 1
#endif
namespace migraphx {
#if MIGRAPHX_HAS_BUFFER_ADDR
template <typename T>
__device__ vec<int32_t, 4> make_wave_buffer_resource(T* p_wave, index_int bytes)
{
union resource
{
// using ptr = const void*;
vec<int32_t, 4> content;
const void* address[2];
int32_t chunk[4];
};
resource result;
result.address[0] = p_wave;
result.chunk[2] = bytes;
result.chunk[3] = 0x00020000; // 0x31014000 for navi
return result.content;
}
template<class T>
struct raw_buffer_load;
#define MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(m) \
m(i8, int8_t) \
m(i16, int16_t) \
m(f16, half) \
m(f32, float)
#define MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, ...) \
__device__ __VA_ARGS__ \
llvm_amdgcn_raw_buffer_load_ ## llvmtype(vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load." #llvmtype); \
template<> \
struct raw_buffer_load<__VA_ARGS__> { \
static __device__ __VA_ARGS__ apply(vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) { \
return llvm_amdgcn_raw_buffer_load_ ## llvmtype(srsrc, voffset, soffset, glc_slc); \
} \
};
#define MIGRAPHX_RAW_BUFFER_LOAD_VEC(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_LOAD(v2 ## llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_LOAD(v4 ## llvmtype, vec<cpptype, 4>)
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(MIGRAPHX_RAW_BUFFER_LOAD_VEC)
#define MIGRAPHX_RAW_BUFFER_STORE(llvmtype, ...) \
__device__ void \
llvm_amdgcn_raw_buffer_store_ ## llvmtype(__VA_ARGS__ vdata, vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store." #llvmtype); \
__device__ void raw_buffer_store(__VA_ARGS__ vdata, vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) { \
llvm_amdgcn_raw_buffer_store_ ## llvmtype(vdata, srsrc, voffset, soffset, glc_slc); \
}
#define MIGRAPHX_RAW_BUFFER_STORE_VEC(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_STORE(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_STORE(v2 ## llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_STORE(v4 ## llvmtype, vec<cpptype, 4>)
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(MIGRAPHX_RAW_BUFFER_STORE_VEC)
template<class T, index_int N>
struct raw_buffer_load<vec<T, N>>
{
static __device__ vec<T, N> apply(vec<int32_t, 4> srsrc,
int32_t voffset,
int32_t soffset,
int32_t glc_slc)
{
static_assert(N % 2 == 0, "Invalid vector size");
union type
{
vec<T, N> data;
vec<T, N/2> reg[2];
};
type result;
auto offset = sizeof(T) * (N/2);
result.reg[0] = raw_buffer_load<vec<T, N/2>>::apply(srsrc, voffset+offset, soffset, glc_slc);
result.reg[1] = raw_buffer_load<vec<T, N/2>>::apply(srsrc, voffset, soffset, glc_slc);
return result.data;
}
};
template<class T, index_int N>
__device__ void raw_buffer_store(vec<T, N> vdata, vec<int32_t, 4> srsrc,
int32_t voffset,
int32_t soffset,
int32_t glc_slc) {
union type
{
vec<T, N> data;
vec<T, N/2> reg[2];
};
type x;
x.data = vdata;
auto offset = sizeof(T) * (N/2);
raw_buffer_store(x.reg[0], srsrc, voffset, soffset, glc_slc);
raw_buffer_store(x.reg[1], srsrc, voffset+offset, soffset, glc_slc);
}
template<class T>
__device__ T buffer_load(const T* p, index_int offset, index_int size, address_space::global)
{
auto resource = make_wave_buffer_resource(p, size * sizeof(T));
return raw_buffer_load<T>::apply(resource, offset * sizeof(T), 0, 0);
}
template<class T>
__device__ void buffer_store(T data, T* p, index_int offset, index_int size, address_space::global)
{
auto resource = make_wave_buffer_resource(p, size * sizeof(T));
return raw_buffer_store(data, resource, offset * sizeof(T), 0, 0);
}
#endif
template<class T, class AddressSpace>
__device__ T buffer_load(const T* p, index_int offset, index_int, AddressSpace)
{
return p[offset];
}
template<class T, class AddressSpace>
__device__ void buffer_store(T data, T* p, index_int offset, index_int, AddressSpace)
{
p[offset] = data;
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
......@@ -69,7 +69,7 @@ template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
idx.global_stride(out.get_shape().elements(),
[&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
[&](auto i) { out.store(i, implicit_conversion(f(xs.load(i)...))); });
}
template <class... Transforms>
......
......@@ -102,14 +102,14 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
return x.with(buffer + offset, address_space::local{});
}
else
{
auto b = as_vec(tensor_vec_size(x), buffer + offset);
idx.local_stride(x.get_shape().element_space(),
[&](auto i) { b[i] = x.data()[i]; });
return x.with(b);
return x.with(b, address_space::local{});
}
}
else
......@@ -172,7 +172,7 @@ __device__ auto preload_copy(index idx, T x)
auto input = as_vec<n>(remove_bool(x.data()));
auto b = as_vec<n>(remove_bool(buffer));
idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; });
return f(x.with(buffer));
return f(x.with(buffer, address_space::local{}));
}
else
{
......
......@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/buffer_address.hpp>
namespace migraphx {
......@@ -41,7 +42,7 @@ struct tensor_view_iterator_read
}
};
template <class T, class Shape>
template <class T, class Shape, class AddressSpace = address_space::global>
struct tensor_view
{
using type = T;
......@@ -71,6 +72,26 @@ struct tensor_view
return x[ito.offset];
}
__device__ T load(MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
{
index_to_offset ito = i;
MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
i,
"Out of bounds access at offset: ",
ito.offset);
return buffer_load(x, ito.offset, get_shape().element_space(), AddressSpace{});
}
__device__ void store(MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i, T d) const
{
index_to_offset ito = i;
MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
i,
"Out of bounds access at offset: ",
ito.offset);
buffer_store(d, x, ito.offset, get_shape().element_space(), AddressSpace{});
}
constexpr T* data() const { return x; }
constexpr auto begin() const { return iterator{0, {this}}; }
......@@ -83,13 +104,19 @@ struct tensor_view
return iterator{get_shape().single(i), {this}};
}
template <class U>
constexpr tensor_view<U, Shape> with(U* y) const
template <class U, class AS>
constexpr tensor_view<U, Shape, AS> with(U* y, AS) const
{
static_assert(sizeof(T) == sizeof(U), "Not the same size");
return {y};
}
template <class U>
constexpr auto with(U* y) const
{
return with(y, AddressSpace{});
}
T* x;
};
......
......@@ -39,6 +39,12 @@ using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
using half2 = migraphx::vec<half, 2>;
struct address_space
{
struct global {};
struct local {};
};
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment