gaoqiong / composable_kernel · Commits

Commit f3edca63
authored Mar 09, 2023 by ltqin

bfloat16_t save and atomic

Parent: da0bb989

Showing 4 changed files with 93 additions and 2 deletions:
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp  +1 -1
include/ck/utility/amd_buffer_addressing.hpp  +70 -0
include/ck/utility/data_type.hpp  +1 -1
include/ck/utility/generic_memory_space_atomic.hpp  +21 -0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1.cpp
@@ -50,7 +50,7 @@ template <ck::index_t... Is>
 using S   = ck::Sequence<Is...>;
 using F16 = ck::half_t;
-using BF16 = ck::bhalf_t;
+using BF16 = ck::bfloat16_t;
 using F32 = float;
 using U16 = unsigned short;
include/ck/utility/amd_buffer_addressing.hpp
@@ -296,6 +296,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, bfloat16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
@@ -419,6 +420,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             return bit_cast<bhalf8_t>(tmp);
         }
     }
+    else if constexpr(is_same<T, bfloat16_t>::value)
+    {
+        if constexpr(N == 1)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 2)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x2(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 4)
+        {
+            return llvm_amdgcn_raw_buffer_load_i16x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+        }
+        else if constexpr(N == 8)
+        {
+            int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            return bit_cast<bfloat16x8_t>(tmp);
+        }
+    }
     else if constexpr(is_same<T, int32_t>::value)
     {
         if constexpr(N == 1)
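Note (not part of the diff): the N == 8 path above loads four 32-bit dwords with the i32x4 intrinsic and reinterprets them as eight bfloat16 lanes, which is valid only because both views cover the same 16 bytes. The following standalone C++20 sketch shows the same size-preserving reinterpretation with std::bit_cast; the std::array types stand in for CK's int32x4_t and bfloat16x8_t vector types and are not CK code.

#include <array>
#include <bit>
#include <cstdint>

int main()
{
    std::array<int32_t, 4> dwords{}; // 16 bytes, as a raw i32x4 load would return
    static_assert(sizeof(dwords) == 8 * sizeof(uint16_t), "both views cover 16 bytes");

    // Reinterpret the same 16 bytes as eight bf16 bit patterns: no value conversion,
    // only the view of the bits changes.
    auto bf16_lanes = std::bit_cast<std::array<uint16_t, 8>>(dwords);
    return static_cast<int>(bf16_lanes[0]);
}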
@@ -552,6 +578,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, bfloat16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
             (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
@@ -697,6 +724,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                 0);
         }
     }
+    else if constexpr(is_same<T, bfloat16_t>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+                                             dst_wave_buffer_resource,
+                                             dst_thread_addr_offset,
+                                             dst_wave_addr_offset,
+                                             0);
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+        }
+        else if constexpr(N == 4)
+        {
+            llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+        }
+        else if constexpr(N == 8)
+        {
+            vector_type<bhalf_t, 8> tmp{bit_cast<vector_type<bhalf_t, 8>>(src_thread_data)};
+
+            llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset,
+                                               0);
+
+            llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
+                                               dst_wave_buffer_resource,
+                                               dst_thread_addr_offset,
+                                               dst_wave_addr_offset + 4 * sizeof(bhalf_t),
+                                               0);
+        }
+    }
     else if constexpr(is_same<T, int32_t>::value)
     {
         if constexpr(N == 1)
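Note (not part of the diff): for N == 8 the branch above issues two 4-wide 16-bit stores, the second at a byte offset of 4 * sizeof(bhalf_t), rather than a single 8-wide store. The plain C++ sketch below mirrors that split with ordinary memcpy calls on an 8-lane bf16 vector modelled as uint16_t lanes; the function name and types are illustrative, not CK APIs.

#include <array>
#include <cstdint>
#include <cstring>

void store_bf16x8_as_two_halves(uint16_t* dst, const std::array<uint16_t, 8>& v)
{
    constexpr std::size_t half_lanes = 4;
    constexpr std::size_t half_bytes = half_lanes * sizeof(uint16_t); // 8 bytes, i.e. 4 * sizeof(bhalf_t)

    // first half at the base address
    std::memcpy(dst, v.data(), half_bytes);
    // second half four elements (8 bytes) further on
    std::memcpy(reinterpret_cast<char*>(dst) + half_bytes, v.data() + half_lanes, half_bytes);
}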
include/ck/utility/data_type.hpp
@@ -1052,7 +1052,7 @@ inline __host__ __device__ constexpr float type_convert<float, bfloat16_t>(bfloa
     return u.fp32;
 }
 
-// convert fp32 to bfp16
+// convert fp32 to bfp16, rtz
 template <>
 inline __host__ __device__ constexpr bfloat16_t type_convert<bfloat16_t, float>(float x)
 {
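Note (not part of the diff): the updated comment marks this fp32 -> bf16 conversion as "rtz" (round toward zero). The usual way to get that rounding is to keep the upper 16 bits of the IEEE-754 float32 encoding and drop the lower 16 mantissa bits, as sketched below; this is a general illustration, and CK's actual type_convert<bfloat16_t, float> may treat NaN or other edge cases differently.

#include <cstdint>
#include <cstring>

inline uint16_t float_to_bf16_rtz(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));     // reinterpret the float's bit pattern
    return static_cast<uint16_t>(bits >> 16); // keep sign, exponent, and top 7 mantissa bits
}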
include/ck/utility/generic_memory_space_atomic.hpp
@@ -189,6 +189,27 @@ __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
     return x;
 }
 
+template <>
+__device__ bfloat16x2_t atomic_add<bfloat16x2_t>(bfloat16x2_t* p_dst, const bfloat16x2_t& x)
+{
+    U32BF162_ADDR dword_addr;
+    U32BF162 cur_v;
+    U32BF162 new_;
+    uint32_t old_v, new_v;
+    dword_addr.bf162_a = reinterpret_cast<bhalf2_t*>(p_dst);
+    cur_v.u32          = *dword_addr.u32_a;
+
+    do
+    {
+        old_v      = cur_v.u32;
+        new_.bf162 = add_bf16x2_t(cur_v.bf162, reinterpret_cast<bhalf2_t>(x));
+        new_v      = new_.u32;
+        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
+    } while(cur_v.u32 != old_v);
+
+    return x;
+}
+
 // template <>
 // __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
 // {
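Note (not part of the diff): the new atomic_add<bfloat16x2_t> specialisation follows the same compare-and-swap emulation as the existing bhalf2_t one: read the 32-bit word holding both bf16 lanes, compute the updated packed value, and retry with atomicCAS until no other thread has changed the word in between. The standalone HIP sketch below shows that pattern with uint32_t as the packed pair and illustrative helper names (unpack_bf16, pack_bf16_rtz); it is a sketch of the technique, not CK's implementation.

#include <hip/hip_runtime.h>
#include <cstdint>

__device__ inline float unpack_bf16(uint16_t b) { return __uint_as_float(uint32_t(b) << 16); }
__device__ inline uint16_t pack_bf16_rtz(float f) { return uint16_t(__float_as_uint(f) >> 16); }

// Atomically add two packed bf16 lanes (in x) into the packed pair at *p_dst.
__device__ void atomic_add_packed_bf16x2(uint32_t* p_dst, uint32_t x)
{
    uint32_t observed = *p_dst;
    uint32_t assumed;
    do
    {
        assumed = observed;

        // unpack both lanes, add in fp32, repack with truncation (round toward zero)
        uint16_t lo = pack_bf16_rtz(unpack_bf16(assumed & 0xFFFFu) + unpack_bf16(x & 0xFFFFu));
        uint16_t hi = pack_bf16_rtz(unpack_bf16(assumed >> 16) + unpack_bf16(x >> 16));
        uint32_t desired = (uint32_t(hi) << 16) | lo;

        // atomicCAS returns the value actually seen in memory; retry if another thread intervened
        observed = atomicCAS(p_dst, assumed, desired);
    } while(observed != assumed);
}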