Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
010ef9dc
Commit
010ef9dc
authored
Feb 26, 2022
by
rocking
Browse files
replace ushortXXX_t to bhalfXXX_t
parent
63e10e34
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
17 deletions
+18
-17
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+7
-7
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+4
-4
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+7
-6
No files found.
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
010ef9dc
...
@@ -57,13 +57,13 @@ llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
...
@@ -57,13 +57,13 @@ llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.i16"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.i16"
);
__device__
ushort
2_t
__device__
bhalf
2_t
llvm_amdgcn_raw_buffer_load_i16x2
(
int32x4_t
srsrc
,
llvm_amdgcn_raw_buffer_load_i16x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2i16"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2i16"
);
__device__
ushort
4_t
__device__
bhalf
4_t
llvm_amdgcn_raw_buffer_load_i16x4
(
int32x4_t
srsrc
,
llvm_amdgcn_raw_buffer_load_i16x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
...
@@ -156,14 +156,14 @@ llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
...
@@ -156,14 +156,14 @@ llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i16"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i16"
);
__device__
void
__device__
void
llvm_amdgcn_raw_buffer_store_i16x2
(
ushort
2_t
vdata
,
llvm_amdgcn_raw_buffer_store_i16x2
(
bhalf
2_t
vdata
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i16"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i16"
);
__device__
void
__device__
void
llvm_amdgcn_raw_buffer_store_i16x4
(
ushort
4_t
vdata
,
llvm_amdgcn_raw_buffer_store_i16x4
(
bhalf
4_t
vdata
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
voffset
,
index_t
soffset
,
index_t
soffset
,
...
@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
bit_cast
<
ushort
8_t
>
(
tmp
);
return
bit_cast
<
bhalf
8_t
>
(
tmp
);
}
}
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
...
@@ -655,13 +655,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -655,13 +655,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{
{
3
vector_type
<
ushort
,
8
>
tmp
{
src_thread_data
};
3
vector_type
<
ushort
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
ushort
4_t
>
()[
Number
<
0
>
{}],
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf
4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
0
);
0
);
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
ushort
4_t
>
()[
Number
<
1
>
{}],
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf
4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
ushort
),
dst_wave_addr_offset
+
4
*
sizeof
(
ushort
),
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
010ef9dc
...
@@ -207,7 +207,7 @@ template <>
...
@@ -207,7 +207,7 @@ template <>
struct
intrin_mfma_f32_32x32x8bf16_1k
<
32
,
32
>
struct
intrin_mfma_f32_32x32x8bf16_1k
<
32
,
32
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort
4_t
&
reg_a
,
const
ushort
4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf
4_t
&
reg_a
,
const
bhalf
4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
...
@@ -221,7 +221,7 @@ template <>
...
@@ -221,7 +221,7 @@ template <>
struct
intrin_mfma_f32_16x16x16bf16_1k
<
16
,
16
>
struct
intrin_mfma_f32_16x16x16bf16_1k
<
16
,
16
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort
4_t
&
reg_a
,
const
ushort
4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf
4_t
&
reg_a
,
const
bhalf
4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
...
@@ -235,7 +235,7 @@ template <>
...
@@ -235,7 +235,7 @@ template <>
struct
intrin_mfma_f32_32x32x4bf16
<
32
,
32
>
struct
intrin_mfma_f32_32x32x4bf16
<
32
,
32
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort
2_t
&
reg_a
,
const
ushort
2_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf
2_t
&
reg_a
,
const
bhalf
2_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
...
@@ -249,7 +249,7 @@ template <>
...
@@ -249,7 +249,7 @@ template <>
struct
intrin_mfma_f32_16x16x8bf16
<
16
,
16
>
struct
intrin_mfma_f32_16x16x8bf16
<
16
,
16
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
ushort
2_t
&
reg_a
,
const
ushort
2_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
bhalf
2_t
&
reg_a
,
const
bhalf
2_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
...
...
composable_kernel/include/utility/data_type.hpp
View file @
010ef9dc
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
namespace
ck
{
namespace
ck
{
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
// vector_type
// vector_type
...
@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
...
@@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
// bfp16
// bfp16
using
ushort
2_t
=
typename
vector_type
<
ushor
t
,
2
>::
type
;
using
bhalf
2_t
=
typename
vector_type
<
bhalf_
t
,
2
>::
type
;
using
ushort
4_t
=
typename
vector_type
<
ushor
t
,
4
>::
type
;
using
bhalf
4_t
=
typename
vector_type
<
bhalf_
t
,
4
>::
type
;
using
ushort
8_t
=
typename
vector_type
<
ushor
t
,
8
>::
type
;
using
bhalf
8_t
=
typename
vector_type
<
bhalf_
t
,
8
>::
type
;
using
ushort
16_t
=
typename
vector_type
<
ushor
t
,
16
>::
type
;
using
bhalf
16_t
=
typename
vector_type
<
bhalf_
t
,
16
>::
type
;
using
ushort
32_t
=
typename
vector_type
<
ushor
t
,
32
>::
type
;
using
bhalf
32_t
=
typename
vector_type
<
bhalf_
t
,
32
>::
type
;
using
ushort
64_t
=
typename
vector_type
<
ushor
t
,
64
>::
type
;
using
bhalf
64_t
=
typename
vector_type
<
bhalf_
t
,
64
>::
type
;
// i32
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment