Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
05605886
Commit
05605886
authored
Sep 14, 2022
by
Paul
Browse files
Use buffer addressing
parent
7662d9c0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
191 additions
and
7 deletions
+191
-7
src/targets/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
...s/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
+151
-0
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
+3
-3
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
...gets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
+30
-3
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+6
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
0 → 100644
View file @
05605886
#ifndef MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
#define MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
#include <migraphx/kernels/types.hpp>
#ifndef MIGRAPHX_HAS_BUFFER_ADDR
#define MIGRAPHX_HAS_BUFFER_ADDR 1
#endif
namespace
migraphx
{
#if MIGRAPHX_HAS_BUFFER_ADDR
// Builds the four-word (4 x int32_t) buffer resource value that the raw buffer
// load/store wrappers in this header take as their `srsrc` argument.
//   p_wave - base pointer, stored in the first pointer-sized slot
//   bytes  - stored in the third 32-bit word (callers in this header pass the
//            buffer size in bytes)
// The fourth 32-bit word is set to a fixed constant.
template <typename T>
__device__ vec<int32_t, 4> make_wave_buffer_resource(T* p_wave, index_int bytes)
{
    // Three views over the same storage: the vector that is returned, a
    // pointer-sized view for the base address, and a per-word view.
    union descriptor
    {
        vec<int32_t, 4> content;
        const void* address[2];
        int32_t chunk[4];
    };

    descriptor desc;
    desc.address[0] = p_wave;
    desc.chunk[2]   = bytes;
    // 0x31014000 for navi
    desc.chunk[3] = 0x00020000;
    return desc.content;
}
// Primary template: declared here without a definition, so only the
// specializations provided in this header can be instantiated.
template
<
class
T
>
struct
raw_buffer_load
;
// X-macro listing the (LLVM type suffix, C++ element type) pairs for which
// raw buffer wrappers are generated; `m` is invoked once per pair.
#define MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(m) \
m(i8, int8_t) \
m(i16, int16_t) \
m(f16, half) \
m(f32, float)
// For one (llvmtype, C++ type) pair this macro:
//   1. declares a __device__ function whose symbol is bound through __asm to
//      "llvm.amdgcn.raw.buffer.load.<llvmtype>", and
//   2. defines the raw_buffer_load specialization for that C++ type, whose
//      apply() forwards its four arguments to that function.
// The C++ type is taken through __VA_ARGS__ so that a type containing a comma,
// such as vec<T, 2>, can be passed as a single macro argument.
#define MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, ...) \
__device__ __VA_ARGS__ \
llvm_amdgcn_raw_buffer_load_ ## llvmtype(vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load." #llvmtype); \
template<> \
struct raw_buffer_load<__VA_ARGS__> { \
static __device__ __VA_ARGS__ apply(vec<int32_t, 4> srsrc, \
int32_t voffset, \
int32_t soffset, \
int32_t glc_slc) { \
return llvm_amdgcn_raw_buffer_load_ ## llvmtype(srsrc, voffset, soffset, glc_slc); \
} \
};
// Generates the load wrappers for one element type in three widths: the
// scalar, the 2-element vector (v2 prefix) and the 4-element vector (v4 prefix).
#define MIGRAPHX_RAW_BUFFER_LOAD_VEC(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_LOAD(v2 ## llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_LOAD(v4 ## llvmtype, vec<cpptype, 4>)
// Expand the load wrappers for every pair in MIGRAPHX_BUFFER_ADDR_VISIT_TYPES.
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES
(
MIGRAPHX_RAW_BUFFER_LOAD_VEC
)
// For one (llvmtype, C++ type) pair this macro:
//   1. declares a __device__ function whose symbol is bound through __asm to
//      "llvm.amdgcn.raw.buffer.store.<llvmtype>", and
//   2. defines a raw_buffer_store overload for that C++ type which forwards
//      its five arguments to that function.
// The C++ type is taken through __VA_ARGS__ so that a type containing a comma,
// such as vec<T, 2>, can be passed as a single macro argument.
// The overload is marked inline: it is a non-template function defined in a
// header, so without inline every translation unit including this header
// would emit its own conflicting definition.
#define MIGRAPHX_RAW_BUFFER_STORE(llvmtype, ...) \
    __device__ void \
    llvm_amdgcn_raw_buffer_store_ ## llvmtype(__VA_ARGS__ vdata, vec<int32_t, 4> srsrc, \
                                              int32_t voffset, \
                                              int32_t soffset, \
                                              int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store." #llvmtype); \
    inline __device__ void raw_buffer_store(__VA_ARGS__ vdata, vec<int32_t, 4> srsrc, \
                                            int32_t voffset, \
                                            int32_t soffset, \
                                            int32_t glc_slc) { \
        llvm_amdgcn_raw_buffer_store_ ## llvmtype(vdata, srsrc, voffset, soffset, glc_slc); \
    }
// Generates the store wrappers for one element type in three widths: the
// scalar, the 2-element vector (v2 prefix) and the 4-element vector (v4 prefix).
#define MIGRAPHX_RAW_BUFFER_STORE_VEC(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_STORE(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_STORE(v2 ## llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_STORE(v4 ## llvmtype, vec<cpptype, 4>)
// Expand the store wrappers for every pair in MIGRAPHX_BUFFER_ADDR_VISIT_TYPES.
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES
(
MIGRAPHX_RAW_BUFFER_STORE_VEC
)
// Fallback load for vector widths that have no macro-generated specialization:
// the vector is read as two half-width loads, recursing until one of the
// explicit raw_buffer_load specializations is reached.
//   srsrc            - buffer resource value (see make_wave_buffer_resource)
//   voffset          - byte offset of the first element of the vector
//   soffset, glc_slc - forwarded unchanged to both half-width loads
// Returns the N-element vector assembled from the two halves.
template <class T, index_int N>
struct raw_buffer_load<vec<T, N>>
{
    static __device__ vec<T, N>
    apply(vec<int32_t, 4> srsrc, int32_t voffset, int32_t soffset, int32_t glc_slc)
    {
        static_assert(N % 2 == 0, "Invalid vector size");
        // reg[] overlays data: reg[0] is the lower N/2 elements, reg[1] the upper.
        union type
        {
            vec<T, N> data;
            vec<T, N / 2> reg[2];
        };
        type result;
        // Size in bytes of one half of the vector.
        const auto offset = sizeof(T) * (N / 2);
        // The lower half lives at voffset and the upper half directly after it.
        // This is the same layout the vector raw_buffer_store overload writes;
        // the halves were previously read in the opposite order, which swapped
        // them on a store/load round trip.
        result.reg[0] = raw_buffer_load<vec<T, N / 2>>::apply(srsrc, voffset, soffset, glc_slc);
        result.reg[1] =
            raw_buffer_load<vec<T, N / 2>>::apply(srsrc, voffset + offset, soffset, glc_slc);
        return result.data;
    }
};
// Fallback store for vector widths that have no macro-generated overload:
// the vector is written as two half-width stores, recursing until one of the
// non-template raw_buffer_store overloads is reached.
//   vdata            - the N-element vector to write
//   srsrc            - buffer resource value (see make_wave_buffer_resource)
//   voffset          - byte offset of the first element of the vector
//   soffset, glc_slc - forwarded unchanged to both half-width stores
template <class T, index_int N>
__device__ void raw_buffer_store(
    vec<T, N> vdata, vec<int32_t, 4> srsrc, int32_t voffset, int32_t soffset, int32_t glc_slc)
{
    // Same precondition as raw_buffer_load<vec<T, N>>: the split into two
    // equal halves is only meaningful for an even element count.
    static_assert(N % 2 == 0, "Invalid vector size");
    // reg[] overlays data: reg[0] is the lower N/2 elements, reg[1] the upper.
    union type
    {
        vec<T, N> data;
        vec<T, N / 2> reg[2];
    };
    type x;
    x.data = vdata;
    // Size in bytes of one half of the vector.
    const auto offset = sizeof(T) * (N / 2);
    raw_buffer_store(x.reg[0], srsrc, voffset, soffset, glc_slc);
    raw_buffer_store(x.reg[1], srsrc, voffset + offset, soffset, glc_slc);
}
// buffer_load overload selected by the address_space::global tag: reads
// element `offset` of the `size`-element buffer starting at `p` through the
// raw_buffer_load wrappers.
//   p      - base pointer of the buffer
//   offset - element index to read
//   size   - number of elements in the buffer
// Returns the loaded value.
template <class T>
__device__ T buffer_load(const T* p, index_int offset, index_int size, address_space::global)
{
    // Both the buffer extent and the element position are handed on in bytes.
    const auto extent_in_bytes = size * sizeof(T);
    const auto byte_offset     = offset * sizeof(T);
    const auto descriptor      = make_wave_buffer_resource(p, extent_in_bytes);
    return raw_buffer_load<T>::apply(descriptor, byte_offset, 0, 0);
}
// buffer_store overload selected by the address_space::global tag: writes
// `data` to element `offset` of the `size`-element buffer starting at `p`
// through the raw_buffer_store overloads.
//   data   - value to write
//   p      - base pointer of the buffer
//   offset - element index to write
//   size   - number of elements in the buffer
template <class T>
__device__ void buffer_store(T data, T* p, index_int offset, index_int size, address_space::global)
{
    // Both the buffer extent and the element position are handed on in bytes.
    const auto extent_in_bytes = size * sizeof(T);
    const auto byte_offset     = offset * sizeof(T);
    const auto descriptor      = make_wave_buffer_resource(p, extent_in_bytes);
    raw_buffer_store(data, descriptor, byte_offset, 0, 0);
}
#endif
// Generic buffer_load accepting any address-space tag: a plain read of
// element `offset` from `p`. The size argument and the tag are unused.
template <class T, class AddressSpace>
__device__ T buffer_load(const T* p, index_int offset, index_int, AddressSpace)
{
    return *(p + offset);
}
// Generic buffer_store accepting any address-space tag: a plain write of
// `data` to element `offset` of `p`. The size argument and the tag are unused.
template <class T, class AddressSpace>
__device__ void buffer_store(T data, T* p, index_int offset, index_int, AddressSpace)
{
    *(p + offset) = data;
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
View file @
05605886
...
...
@@ -69,7 +69,7 @@ template <class F, class T, class... Ts>
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
{
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
out
[
i
]
=
implicit_conversion
(
f
(
xs
[
i
]
...));
});
[
&
](
auto
i
)
{
out
.
store
(
i
,
implicit_conversion
(
f
(
xs
.
load
(
i
)
...))
)
;
});
}
template
<
class
...
Transforms
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
View file @
05605886
...
...
@@ -102,14 +102,14 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
auto
b
=
as_vec
(
tensor_vec_size
(
v
),
buffer
+
offset
);
idx
.
local_stride
(
v
.
get_shape
().
element_space
(),
[
&
](
auto
i
)
{
b
[
i
]
=
v
.
data
()[
i
];
});
return
x
.
with
(
buffer
+
offset
);
return
x
.
with
(
buffer
+
offset
,
address_space
::
local
{}
);
}
else
{
auto
b
=
as_vec
(
tensor_vec_size
(
x
),
buffer
+
offset
);
idx
.
local_stride
(
x
.
get_shape
().
element_space
(),
[
&
](
auto
i
)
{
b
[
i
]
=
x
.
data
()[
i
];
});
return
x
.
with
(
b
);
return
x
.
with
(
b
,
address_space
::
local
{}
);
}
}
else
...
...
@@ -172,7 +172,7 @@ __device__ auto preload_copy(index idx, T x)
auto
input
=
as_vec
<
n
>
(
remove_bool
(
x
.
data
()));
auto
b
=
as_vec
<
n
>
(
remove_bool
(
buffer
));
idx
.
local_stride
(
size
/
n
,
[
&
](
auto
i
)
{
b
[
i
]
=
input
[
i
];
});
return
f
(
x
.
with
(
buffer
));
return
f
(
x
.
with
(
buffer
,
address_space
::
local
{}
));
}
else
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
View file @
05605886
...
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/buffer_address.hpp>
namespace
migraphx
{
...
...
@@ -41,7 +42,7 @@ struct tensor_view_iterator_read
}
};
template
<
class
T
,
class
Shape
>
template
<
class
T
,
class
Shape
,
class
AddressSpace
=
address_space
::
global
>
struct
tensor_view
{
using
type
=
T
;
...
...
@@ -71,6 +72,26 @@ struct tensor_view
return
x
[
ito
.
offset
];
}
// Reads one element of the view through buffer_load, passing the view's
// pointer, the resolved offset, get_shape().element_space() as the size and
// an AddressSpace tag so the matching buffer_load overload is chosen.
//   i - index, converted to an index_to_offset to obtain the linear offset
// Returns the value produced by buffer_load.
__device__
T
load
(
MIGRAPHX_CAPTURE_SOURCE_LOCATION
(
index_to_offset
)
i
)
const
{
index_to_offset
ito
=
i
;
// NOTE(review): presumably reports when the condition (offset below
// element_space()) is false - confirm against the MIGRAPHX_WARN definition.
MIGRAPHX_WARN
(
ito
.
offset
<
get_shape
().
element_space
(),
i
,
"Out of bounds access at offset: "
,
ito
.
offset
);
return
buffer_load
(
x
,
ito
.
offset
,
get_shape
().
element_space
(),
AddressSpace
{});
}
// Writes one element of the view through buffer_store, passing the value, the
// view's pointer, the resolved offset, get_shape().element_space() as the size
// and an AddressSpace tag so the matching buffer_store overload is chosen.
//   i - index, converted to an index_to_offset to obtain the linear offset
//   d - value to write
__device__
void
store
(
MIGRAPHX_CAPTURE_SOURCE_LOCATION
(
index_to_offset
)
i
,
T
d
)
const
{
index_to_offset
ito
=
i
;
// NOTE(review): presumably reports when the condition (offset below
// element_space()) is false - confirm against the MIGRAPHX_WARN definition.
MIGRAPHX_WARN
(
ito
.
offset
<
get_shape
().
element_space
(),
i
,
"Out of bounds access at offset: "
,
ito
.
offset
);
buffer_store
(
d
,
x
,
ito
.
offset
,
get_shape
().
element_space
(),
AddressSpace
{});
}
// Returns the raw pointer held by this view.
constexpr T* data() const { return this->x; }
// Returns the view's iterator constructed from index 0 and this view's address.
constexpr auto begin() const
{
    return iterator{0, {this}};
}
...
...
@@ -83,13 +104,19 @@ struct tensor_view
return
iterator
{
get_shape
().
single
(
i
),
{
this
}};
}
template
<
class
U
>
constexpr
tensor_view
<
U
,
Shape
>
with
(
U
*
y
)
const
// Returns a view with the same Shape over a different pointer and with the
// address-space tag type AS.
//   y - pointer for the new view; U must have the same size as T
//   (the tag argument is used only for its type)
template <class U, class AS>
constexpr tensor_view<U, Shape, AS> with(U* y, AS) const
{
    static_assert(sizeof(T) == sizeof(U), "Not the same size");
    return {y};
}
template
<
class
U
>
constexpr
auto
with
(
U
*
y
)
const
{
return
with
(
y
,
AddressSpace
{});
}
T
*
x
;
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
05605886
...
...
@@ -39,6 +39,12 @@ using vec = T __attribute__((ext_vector_type(N)));
using
half
=
_Float16
;
using
half2
=
migraphx
::
vec
<
half
,
2
>
;
// Holder for the empty tag types used to select an address-space-specific
// overload (for example in buffer_load / buffer_store).
struct address_space
{
    // Tag type named `global`.
    struct global
    {
    };
    // Tag type named `local`.
    struct local
    {
    };
};
}
// namespace migraphx
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment