Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
54e3ce2e
Commit
54e3ce2e
authored
Sep 14, 2022
by
Paul
Browse files
Format
parent
05605886
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
69 deletions
+71
-69
src/targets/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
...s/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
+65
-67
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+6
-2
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/buffer_address.hpp
View file @
54e3ce2e
...
@@ -22,110 +22,109 @@ __device__ vec<int32_t, 4> make_wave_buffer_resource(T* p_wave, index_int bytes)
...
@@ -22,110 +22,109 @@ __device__ vec<int32_t, 4> make_wave_buffer_resource(T* p_wave, index_int bytes)
};
};
resource
result
;
resource
result
;
result
.
address
[
0
]
=
p_wave
;
result
.
address
[
0
]
=
p_wave
;
result
.
chunk
[
2
]
=
bytes
;
result
.
chunk
[
2
]
=
bytes
;
result
.
chunk
[
3
]
=
0x00020000
;
// 0x31014000 for navi
result
.
chunk
[
3
]
=
0x00020000
;
// 0x31014000 for navi
return
result
.
content
;
return
result
.
content
;
}
}
template
<
class
T
>
template
<
class
T
>
struct
raw_buffer_load
;
struct
raw_buffer_load
;
#define MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(m) \
#define MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(m) m(i8, int8_t) m(i16, int16_t) m(f16, half) m(f32, float)
m(i8, int8_t) \
m(i16, int16_t) \
#define MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, ...) \
m(f16, half) \
__device__ __VA_ARGS__ llvm_amdgcn_raw_buffer_load_##llvmtype( \
m(f32, float)
vec<int32_t, 4> srsrc, \
int32_t voffset, \
#define MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, ...) \
int32_t soffset, \
__device__ __VA_ARGS__ \
int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load." #llvmtype); \
llvm_amdgcn_raw_buffer_load_ ## llvmtype(vec<int32_t, 4> srsrc, \
template <> \
int32_t voffset, \
struct raw_buffer_load<__VA_ARGS__> \
int32_t soffset, \
{ \
int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load." #llvmtype); \
static __device__ __VA_ARGS__ apply(vec<int32_t, 4> srsrc, \
template<> \
int32_t voffset, \
struct raw_buffer_load<__VA_ARGS__> { \
int32_t soffset, \
static __device__ __VA_ARGS__ apply(vec<int32_t, 4> srsrc, \
int32_t glc_slc) \
int32_t voffset, \
{ \
int32_t soffset, \
return llvm_amdgcn_raw_buffer_load_##llvmtype(srsrc, voffset, soffset, glc_slc); \
int32_t glc_slc) { \
} \
return llvm_amdgcn_raw_buffer_load_ ## llvmtype(srsrc, voffset, soffset, glc_slc); \
} \
};
};
#define MIGRAPHX_RAW_BUFFER_LOAD_VEC(llvmtype, cpptype) \
#define MIGRAPHX_RAW_BUFFER_LOAD_VEC(llvmtype, cpptype)
\
MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, cpptype) \
MIGRAPHX_RAW_BUFFER_LOAD(llvmtype, cpptype)
\
MIGRAPHX_RAW_BUFFER_LOAD(v2
##
llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_LOAD(v2##llvmtype, vec<cpptype, 2>) \
MIGRAPHX_RAW_BUFFER_LOAD(v4
##
llvmtype, vec<cpptype, 4>)
MIGRAPHX_RAW_BUFFER_LOAD(v4##llvmtype, vec<cpptype, 4>)
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES
(
MIGRAPHX_RAW_BUFFER_LOAD_VEC
)
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES
(
MIGRAPHX_RAW_BUFFER_LOAD_VEC
)
// Declares a device function whose asm label is
// "llvm.amdgcn.raw.buffer.store.<llvmtype>" and defines a raw_buffer_store
// overload for the C++ type given in the variadic arguments (variadic so that
// types containing commas, such as vec<float, 2>, can be passed). The overload
// forwards its five arguments unchanged to that function.
#define MIGRAPHX_RAW_BUFFER_STORE(llvmtype, ...)                                          \
    __device__ void llvm_amdgcn_raw_buffer_store_##llvmtype(                              \
        __VA_ARGS__ vdata,                                                                \
        vec<int32_t, 4> srsrc,                                                            \
        int32_t voffset,                                                                  \
        int32_t soffset,                                                                  \
        int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store." #llvmtype);                \
    __device__ void raw_buffer_store(__VA_ARGS__ vdata,                                   \
                                     vec<int32_t, 4> srsrc,                               \
                                     int32_t voffset,                                     \
                                     int32_t soffset,                                     \
                                     int32_t glc_slc)                                     \
    {                                                                                     \
        llvm_amdgcn_raw_buffer_store_##llvmtype(vdata, srsrc, voffset, soffset, glc_slc); \
    }

// Generates the scalar, 2-wide and 4-wide store wrappers for one element type.
#define MIGRAPHX_RAW_BUFFER_STORE_VEC(llvmtype, cpptype)     \
    MIGRAPHX_RAW_BUFFER_STORE(llvmtype, cpptype)             \
    MIGRAPHX_RAW_BUFFER_STORE(v2##llvmtype, vec<cpptype, 2>) \
    MIGRAPHX_RAW_BUFFER_STORE(v4##llvmtype, vec<cpptype, 4>)

// Instantiate the store wrappers for every visited element type.
MIGRAPHX_BUFFER_ADDR_VISIT_TYPES(MIGRAPHX_RAW_BUFFER_STORE_VEC)
// Loads a vec<T, N> by splitting it into two vec<T, N / 2> halves and loading
// each half through raw_buffer_load<vec<T, N / 2>>. The recursion ends at the
// explicit specializations generated by MIGRAPHX_RAW_BUFFER_LOAD_VEC.
//
// srsrc, soffset and glc_slc are forwarded unchanged to the half loads;
// voffset is the byte offset of the first element.
template <class T, index_int N>
struct raw_buffer_load<vec<T, N>>
{
    static __device__ vec<T, N>
    apply(vec<int32_t, 4> srsrc, int32_t voffset, int32_t soffset, int32_t glc_slc)
    {
        static_assert(N % 2 == 0, "Invalid vector size");
        // View the full vector as two consecutive halves.
        union type
        {
            vec<T, N> data;
            vec<T, N / 2> reg[2];
        };
        type result;
        // Size in bytes of one half, i.e. the distance between the two loads.
        const auto half_bytes = sizeof(T) * (N / 2);
        // reg[0] overlays the first N / 2 elements of data, so it must come
        // from voffset and reg[1] from voffset + half_bytes. This matches the
        // layout written by raw_buffer_store for vec<T, N>; the previous code
        // had the two offsets swapped, which exchanged the halves.
        result.reg[0] =
            raw_buffer_load<vec<T, N / 2>>::apply(srsrc, voffset, soffset, glc_slc);
        result.reg[1] =
            raw_buffer_load<vec<T, N / 2>>::apply(srsrc, voffset + half_bytes, soffset, glc_slc);
        return result.data;
    }
};
// Stores a vec<T, N> by splitting it into two vec<T, N / 2> halves and storing
// each half with raw_buffer_store. The first half is written at voffset and
// the second half at voffset plus the byte size of one half.
//
// srsrc, soffset and glc_slc are forwarded unchanged to the half stores.
template <class T, index_int N>
__device__ void raw_buffer_store(
    vec<T, N> vdata, vec<int32_t, 4> srsrc, int32_t voffset, int32_t soffset, int32_t glc_slc)
{
    // Same precondition as raw_buffer_load<vec<T, N>>: with an odd N the two
    // N / 2 halves would not cover the whole vector and the last element would
    // be dropped silently.
    static_assert(N % 2 == 0, "Invalid vector size");
    // View the full vector as two consecutive halves.
    union type
    {
        vec<T, N> data;
        vec<T, N / 2> reg[2];
    };
    type x;
    x.data = vdata;
    // Size in bytes of one half, i.e. the distance between the two stores.
    const auto half_bytes = sizeof(T) * (N / 2);
    raw_buffer_store(x.reg[0], srsrc, voffset, soffset, glc_slc);
    raw_buffer_store(x.reg[1], srsrc, voffset + half_bytes, soffset, glc_slc);
}
// Global-address-space overload: reads element `offset` of the `size`-element
// array at `p` through a buffer resource instead of a plain pointer
// dereference. Both the resource size and the load offset are converted from
// element counts to bytes.
template <class T>
__device__ T buffer_load(const T* p, index_int offset, index_int size, address_space::global)
{
    const auto total_bytes = size * sizeof(T);
    const auto byte_offset = offset * sizeof(T);
    auto resource          = make_wave_buffer_resource(p, total_bytes);
    // The last two arguments (soffset and glc_slc) are both zero.
    return raw_buffer_load<T>::apply(resource, byte_offset, 0, 0);
}
template
<
class
T
>
template
<
class
T
>
__device__
void
buffer_store
(
T
data
,
T
*
p
,
index_int
offset
,
index_int
size
,
address_space
::
global
)
__device__
void
buffer_store
(
T
data
,
T
*
p
,
index_int
offset
,
index_int
size
,
address_space
::
global
)
{
{
auto
resource
=
make_wave_buffer_resource
(
p
,
size
*
sizeof
(
T
));
auto
resource
=
make_wave_buffer_resource
(
p
,
size
*
sizeof
(
T
));
...
@@ -134,18 +133,17 @@ __device__ void buffer_store(T data, T* p, index_int offset, index_int size, add
...
@@ -134,18 +133,17 @@ __device__ void buffer_store(T data, T* p, index_int offset, index_int size, add
#endif
#endif
// Fallback overload for any address-space tag: a plain indexed read. The size
// argument and the tag are unnamed because they are not used here.
template <class T, class AddressSpace>
__device__ T buffer_load(const T* p, index_int offset, index_int, AddressSpace)
{
    const T* element = p + offset;
    return *element;
}
// Fallback overload for any address-space tag: a plain indexed write. The size
// argument and the tag are unnamed because they are not used here.
template <class T, class AddressSpace>
__device__ void buffer_store(T data, T* p, index_int offset, index_int, AddressSpace)
{
    T* element = p + offset;
    *element   = data;
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
#endif // MIGRAPHX_GUARD_KERNELS_BUFFER_ADDRESS_HPP
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
54e3ce2e
...
@@ -41,8 +41,12 @@ using half2 = migraphx::vec<half, 2>;
...
@@ -41,8 +41,12 @@ using half2 = migraphx::vec<half, 2>;
// Groups the empty tag types passed as the last argument of buffer_load and
// buffer_store to choose an overload.
struct address_space
{
    // Tag selecting the address_space::global overloads.
    struct global
    {
    };
    // Second tag type; distinct from global.
    struct local
    {
    };
};
}
// namespace migraphx
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment