Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5d11ef6d
Commit
5d11ef6d
authored
May 25, 2023
by
carlushuang
Browse files
support dynamic buffer using memory coherence glc_slc bit from template
parent
ac9e01e2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
191 additions
and
104 deletions
+191
-104
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+181
-99
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+10
-5
No files found.
include/ck/utility/amd_buffer_addressing.hpp
View file @
5d11ef6d
...
...
@@ -286,7 +286,18 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
template
<
typename
T
,
index_t
N
>
// memory coherency bit for buffer store/load instruction
enum
struct
amd_buffer_coherence_bits
{
default_coherence
=
0
,
// default value
glc
=
1
,
slc
=
2
,
glc_slc
=
3
,
};
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
...
...
@@ -305,28 +316,37 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
bit_cast
<
double2_t
>
(
f32_0
);
...
...
@@ -339,31 +359,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
...
...
@@ -372,24 +401,32 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
...
...
@@ -398,23 +435,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
...
...
@@ -423,31 +468,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
int32_t
,
8
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
return
tmp
.
AsType
<
int32x8_t
>
()(
Number
<
0
>
{});
}
}
...
...
@@ -455,17 +509,23 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x2_t
>
(
tmp
);
#endif
...
...
@@ -473,11 +533,15 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x4_t
>
(
tmp
);
#endif
...
...
@@ -487,19 +551,24 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
8
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
return
tmp
.
AsType
<
int8x8_t
>
()(
Number
<
0
>
{});
#else
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
#endif
...
...
@@ -509,31 +578,36 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
16
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
2
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int8_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int8_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
return
tmp
.
AsType
<
int8x16_t
>
()(
Number
<
0
>
{});
#else
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
#endif
...
...
@@ -541,7 +615,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
...
...
@@ -565,7 +641,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -573,7 +649,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
...
...
@@ -584,7 +660,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -592,7 +668,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
4
)
{
...
...
@@ -600,7 +676,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
...
...
@@ -611,7 +687,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -619,7 +695,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
4
)
{
...
...
@@ -627,7 +703,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
8
)
{
...
...
@@ -638,19 +714,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0
);
static_cast<index_t>(coherence)
);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0
);
static_cast<index_t>(coherence)
);
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
#endif
}
}
...
...
@@ -662,7 +738,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -670,7 +746,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
4
)
{
...
...
@@ -678,7 +754,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
8
)
{
...
...
@@ -688,13 +764,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
...
...
@@ -705,7 +781,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -713,7 +789,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
4
)
{
...
...
@@ -721,7 +797,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
...
...
@@ -732,7 +808,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
2
)
{
...
...
@@ -741,13 +817,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
#else
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
#endif
}
else
if
constexpr
(
N
==
4
)
...
...
@@ -757,13 +833,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
#else
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
#endif
}
else
if
constexpr
(
N
==
8
)
...
...
@@ -772,7 +848,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
else
if
constexpr
(
N
==
16
)
{
...
...
@@ -780,7 +856,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
}
...
...
@@ -1012,7 +1088,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::typ
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
>
__device__
typename
vector_type_maker
<
T
,
N
>::
type
::
type
amd_buffer_load_invalid_element_return_zero
(
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
...
...
@@ -1032,10 +1110,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#else
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
...
...
@@ -1046,7 +1124,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
>
__device__
typename
vector_type_maker
<
T
,
N
>::
type
::
type
amd_buffer_load_invalid_element_return_customized_value
(
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
...
...
@@ -1064,7 +1144,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
customized_value
);
...
...
@@ -1074,7 +1154,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
>
__device__
void
amd_buffer_store
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
...
...
@@ -1093,12 +1175,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
amd_buffer_store_impl
<
scalar_t
,
vector_size
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
5d11ef6d
...
...
@@ -52,6 +52,7 @@ struct DynamicBuffer
__host__
__device__
constexpr
T
&
operator
()(
index_t
i
)
{
return
p_data_
[
i
];
}
template
<
typename
X
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
...
...
@@ -77,13 +78,16 @@ struct DynamicBuffer
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
...
...
@@ -144,6 +148,7 @@ struct DynamicBuffer
}
template
<
typename
X
,
amd_buffer_coherence_bits
coherence
=
amd_buffer_coherence_bits
::
default_coherence
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
...
...
@@ -173,7 +178,7 @@ struct DynamicBuffer
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment