gaoqiong / composable_kernel

Commit 9eebd123
Authored Apr 10, 2021 by Chao Liu

overhaul vector_type, make int8x4_t real vector instead of aliasing from int32_t

parent 5602817f

Showing 7 changed files with 524 additions and 454 deletions:

  composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                          +1    -1
  composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +51   -42
  composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp                         +1    -1
  composable_kernel/include/utility/amd_buffer_addressing_v2.hpp                            +116  -76
  composable_kernel/include/utility/amd_inline_asm.hpp                                      +69   -39
  composable_kernel/include/utility/config.amd.hpp.in                                       +5    -0
  composable_kernel/include/utility/float_type.amd.hpp.in                                   +281  -295
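For context on the commit message (this is not part of the commit's diff), a minimal sketch of the difference between an aliased and a "real" int8x4_t, assuming the Clang ext_vector_type extension that HIP vector types are typically built on; the actual definitions live in float_type.amd.hpp.in, whose diff is collapsed further down this page.

#include <cstdint>

// Old style (assumed): 4 packed int8 values carried around as one 32-bit integer,
// so per-lane access needs shifts/masks or pointer casts.
using int8x4_alias_t = int32_t;

// New style (assumed): a genuine 4-lane vector type; lanes are addressable directly
// and the compiler knows the element type.
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

int32_t sum_lanes(int8x4_t v)
{
    // Direct per-lane access is what the plain int32_t alias could not provide.
    return v[0] + v[1] + v[2] + v[3];
}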
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

-       using vector_t = typename vector_type<Data, DataPerAccess>::type;
+       using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;

        static_for<0, NSliceRow, 1>{}([&](auto i) {
            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        }();

        // copy data
-       vector_type<DstData, DstScalarPerVector> dst_vector;
+       typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

-       using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
+       using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t src_offset = src_desc.CalculateOffset(
                to_multi_index(src_slice_origin_idx) + dst_data_idx + i * dst_scalar_step_in_vector);

-           dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
+           dst_vector.template AsType<DstData>()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
        });

        const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<DstData, DstScalarPerVector>(
-               dst_vector.Vector(),
+               dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
                p_dst,
                dst_slice_origin_coord_.GetOffset(),
                is_dst_valid,
@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
            if(is_dst_valid)
            {
                *reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                   dst_vector.Vector();
+                   dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
            }
#endif
        }

@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
            if(is_dst_valid)
            {
                *reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                   dst_vector.Vector();
+                   dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
            }
        }
@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        // copy data
        static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");

-       vector_type<SrcData, SrcScalarPerVector> src_vector;
+       typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

-       using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
+       using src_vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;

        const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
            src_desc, src_slice_origin_coord_);
@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        if constexpr(SrcAddressSpace == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
-           src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-               p_src, src_slice_origin_coord_.GetOffset(), is_src_valid, src_desc.GetElementSpaceSize());
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
+                   p_src, src_slice_origin_coord_.GetOffset(), is_src_valid, src_desc.GetElementSpaceSize());
#else
-           src_vector.Vector() =
-               is_src_valid
-                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
-                   : src_vector_t{0};
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               is_src_valid
+                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
+                   : src_vector_t{0};
#endif
        }
        else
        {
-           src_vector.Vector() =
-               is_src_valid
-                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
-                   : src_vector_t{0};
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               is_src_valid
+                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
+                   : src_vector_t{0};
        }

        static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                dst_desc.CalculateOffset(
                    to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector);

-           p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i];
+           p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
        });

        constexpr auto move_on_dim = [&]() constexpr
@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        }();

        // copy data
-       vector_type<SrcData, SrcScalarPerVector> src_vector;
+       typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

-       using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
+       using src_vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;

        const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
            src_desc, src_slice_origin_coord_);
@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        if constexpr(SrcAddressSpace == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
-           src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-               p_src, src_slice_origin_coord_.GetOffset(), is_src_valid, src_desc.GetElementSpaceSize());
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
+                   p_src, src_slice_origin_coord_.GetOffset(), is_src_valid, src_desc.GetElementSpaceSize());
#else
-           src_vector.Vector() =
-               is_src_valid
-                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
-                   : src_vector_t{0};
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               is_src_valid
+                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
+                   : src_vector_t{0};
#endif
        }
        else
        {
-           src_vector.Vector() =
-               is_src_valid
-                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
-                   : src_vector_t{0};
+           src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+               is_src_valid
+                   ? *reinterpret_cast<const src_vector_t*>(&p_src[src_slice_origin_coord_.GetOffset()])
+                   : src_vector_t{0};
        }

        static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t buffer_offset =
                buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);

-           buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
+           buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
        });

        constexpr auto move_on_dim = [&]() constexpr
@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                          DstInMemOp == InMemoryDataOperation::Set,
                      "wrong! hardcoded for ds_write");

-       vector_type<DstData, DstScalarPerVector> dst_vector;
+       typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t buffer_offset =
                buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

-           dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
+           dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
        });

-       using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
+       using DstVectorType = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

        *reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
-           dst_vector.Vector();
+           dst_vector.template AsType<DstVectorType>()[Number<0>{}];

        constexpr auto move_on_dim = [&]() constexpr
        {
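The recurring pattern in this file is that the per-thread staging buffer is no longer a vector_type object with Scalars()/Vector() accessors; it is the object produced by vector_type_maker<T, N>::type, written lane by lane through AsType<T>() and read back as a whole vector through AsType<vector_t>(). A self-contained toy sketch of that storage-view idea follows; it is NOT the composable_kernel implementation (the real wrapper lives in float_type.amd.hpp.in), only an analogy that makes the diff easier to read.

#include <cstdint>
#include <cstdio>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

// Toy stand-in for vector_type_maker<int8_t, 4>::type: the same bytes viewed
// either as individual scalars or as one 4-lane vector.
union ToyInt8x4Wrapper
{
    int8_t   scalar[4];
    int8x4_t vector;
};

int main()
{
    ToyInt8x4Wrapper v{};

    // per-scalar writes (analogue of dst_vector.template AsType<DstData>()(i) = ...)
    for(int i = 0; i < 4; ++i)
        v.scalar[i] = static_cast<int8_t>(i + 1);

    // whole-vector read (analogue of dst_vector.template AsType<dst_vector_t>()[Number<0>{}])
    int8x4_t whole = v.vector;

    std::printf("%d %d %d %d\n", whole[0], whole[1], whole[2], whole[3]);
    return 0;
}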
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp

@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

-       using vector_t = typename vector_type<Data, DataPerAccess>::type;
+       using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;

        static_for<0, NSliceRow, 1>{}([&](auto i) {
            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
@@ -209,16 +209,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
                        index_t src_thread_addr_offset,
                        index_t src_wave_addr_offset)
{
-   static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-                     (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-                     (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
-                     (is_same<T, half2_t>::value && (N == 1)) ||
-                     (is_same<T, half4_t>::value && (N == 1)) ||
-                     (is_same<T, half8_t>::value && (N == 1)) ||
-                     (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-                     (is_same<T, int32x2_t>::value && (N == 1)) ||
-                     (is_same<T, int32x4_t>::value && (N == 1)),
-                 "wrong! not implemented");
+   static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                     (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+                     (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+                     (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
+                 "wrong! not implemented");

    if constexpr(is_same<T, float>::value)
    {
@@ -241,16 +237,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
        {
            vector_type<float, 8> tmp;

-           tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
+           tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-           tmp.Vectors(Number<4>{})(Number<1>{}) =
+           tmp.AsType<float4_t>()(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset + 4 * sizeof(float),
                                                     0);

-           return tmp.Vector();
+           return tmp.AsType<float8_t>()(Number<0>{});
        }
    }
    else if constexpr(is_same<T, half_t>::value)
@@ -270,39 +266,20 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
            return __llvm_amdgcn_raw_buffer_load_fp16x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
        }
-   }
-   else if constexpr(is_same<T, half2_t>::value)
-   {
-       if constexpr(N == 1)
-       {
-           return __llvm_amdgcn_raw_buffer_load_fp16x2(
-               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-       }
-   }
-   else if constexpr(is_same<T, half4_t>::value)
-   {
-       if constexpr(N == 1)
-       {
-           return __llvm_amdgcn_raw_buffer_load_fp16x4(
-               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-       }
-   }
-   else if constexpr(is_same<T, half8_t>::value)
-   {
-       if constexpr(N == 1)
+       else if constexpr(N == 8)
        {
            vector_type<half_t, 8> tmp;

-           tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
+           tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-           tmp.Vectors(Number<4>{})(Number<1>{}) =
+           tmp.AsType<half4_t>()(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset + 4 * sizeof(half_t),
                                                     0);

-           return tmp.Vector();
+           return tmp.AsType<half8_t>()(Number<0>{});
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
@@ -326,15 +303,15 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
        {
            vector_type<int32_t, 8> tmp;

-           tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
+           tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-           tmp.Vectors(Number<4>{})(Number<1>{}) =
+           tmp.AsType<int32x4_t>()(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset + 4 * sizeof(int32_t),
                                                    0);

-           return tmp.Vector();
+           return tmp.AsType<int32x8_t>()(Number<0>{});
        }
    }
    else if constexpr(is_same<T, int8_t>::value)
@@ -346,44 +323,83 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
        }
        else if constexpr(N == 2)
        {
+#if !CK_WORKAROUND_SWDEV_XXXXXX
            return __llvm_amdgcn_raw_buffer_load_i8x2(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+#else
+           int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           return as_type<int8x2_t>(tmp);
+#endif
        }
        else if constexpr(N == 4)
        {
+#if !CK_WORKAROUND_SWDEV_XXXXXX
            return __llvm_amdgcn_raw_buffer_load_i8x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+#else
+           int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           return as_type<int8x4_t>(tmp);
+#endif
        }
        else if constexpr(N == 8)
        {
+#if !CK_WORKAROUND_SWDEV_XXXXXX
            vector_type<int8_t, 8> tmp;

-           tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+           tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-           tmp.Vectors(Number<4>{})(Number<1>{}) =
+           tmp.AsType<int8x4_t>()(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset + 4 * sizeof(int8_t),
                                                   0);

-           return tmp.Vector();
-       }
-   }
-   else if constexpr(is_same<T, int32x2_t>::value)
-   {
-       if constexpr(N == 1)
-       {
-           return __llvm_amdgcn_raw_buffer_load_i32x2(
+           return tmp.AsType<int8x8_t>()(Number<0>{});
+#else
+           int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           return as_type<int8x8_t>(tmp);
+#endif
        }
-   }
-   else if constexpr(is_same<T, int32x4_t>::value)
-   {
-       if constexpr(N == 1)
+       else if constexpr(N == 16)
        {
-           return __llvm_amdgcn_raw_buffer_load_i32x4(
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+           vector_type<int8_t, 16> tmp;
+
+           tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+               src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           tmp.AsType<int8x4_t>()(Number<1>{}) =
+               __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                  src_thread_addr_offset,
+                                                  src_wave_addr_offset + 4 * sizeof(int8_t),
+                                                  0);
+
+           tmp.AsType<int8x4_t>()(Number<2>{}) =
+               __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                  src_thread_addr_offset,
+                                                  src_wave_addr_offset + 8 * sizeof(int8_t),
+                                                  0);
+
+           tmp.AsType<int8x4_t>()(Number<3>{}) =
+               __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                                  src_thread_addr_offset,
+                                                  src_wave_addr_offset + 12 * sizeof(int8_t),
+                                                  0);
+
+           return tmp.AsType<int8x16_t>()(Number<0>{});
+#else
+           int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+           return as_type<int8x16_t>(tmp);
+#endif
        }
    }
}
@@ -467,23 +483,39 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
        }
        else if constexpr(N == 2)
        {
-           __llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+           __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                0);
+#else
+           __llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+#endif
        }
        else if constexpr(N == 4)
        {
-           __llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
+#if !CK_WORKAROUND_SWDEV_XXXXXX
+           __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                0);
+#else
+           __llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+#endif
        }
        else if constexpr(N == 8)
        {
-           __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
+           __llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
                                                 dst_wave_buffer_resource,
                                                 dst_thread_addr_offset,
                                                 dst_wave_addr_offset,
@@ -491,7 +523,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
        }
        else if constexpr(N == 16)
        {
-           __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
+           __llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
                                                 dst_wave_buffer_resource,
                                                 dst_thread_addr_offset,
                                                 dst_wave_addr_offset,
@@ -528,13 +560,13 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
        {
            vector_type<half_t, 8> tmp{src_thread_data};

-           __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
+           __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);

-           __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
+           __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + 4 * sizeof(half_t),
@@ -548,26 +580,29 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// 2) p_src_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
-__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
-                                                                index_t src_thread_data_offset,
-                                                                bool src_thread_data_valid,
-                                                                index_t src_element_space)
+__device__ typename vector_type_maker<T, N>::type::type
+amd_buffer_load_v2(const T* p_src_wave,
+                   index_t src_thread_data_offset,
+                   bool src_thread_data_valid,
+                   index_t src_element_space)
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space);

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);

+   using vector_t = typename vector_type_maker<T, N>::type::type;
+   using scalar_t = typename scalar_type<vector_t>::type;
+
+   constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;

-   return amd_buffer_load_impl_v2<T, N>(
+   return amd_buffer_load_impl_v2<scalar_t, vector_size>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
-   using vector_t = typename vector_type<T, N>::type;
-
-   vector_t tmp = amd_buffer_load_impl_v2<T, N>(
+   vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
        src_wave_buffer_resource, src_thread_addr_offset, 0);

    return src_thread_data_valid ? tmp : vector_t(0);
#endif
@@ -578,26 +613,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wa
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
-__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
-                                    T* p_dst_wave,
-                                    const index_t dst_thread_data_offset,
-                                    const bool dst_thread_data_valid,
-                                    const index_t dst_element_space)
+__device__ void amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
+                                    T* p_dst_wave,
+                                    const index_t dst_thread_data_offset,
+                                    const bool dst_thread_data_valid,
+                                    const index_t dst_element_space)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space);

    index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);

+   using vector_t = typename vector_type_maker<T, N>::type::type;
+   using scalar_t = typename scalar_type<vector_t>::type;
+
+   constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;

-   amd_buffer_store_impl_v2<T, N>(
+   amd_buffer_store_impl_v2<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_data_valid)
    {
-       amd_buffer_store_impl_v2<T, N>(
+       amd_buffer_store_impl_v2<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
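The two wrapper functions above now normalize the template pair (T, N) to a scalar element type and a total lane count before dispatching to the *_impl_v2 overloads, presumably so that a call such as amd_buffer_load_v2<int8x4_t, 2> ends up on the same int8_t-by-8 implementation as amd_buffer_load_v2<int8_t, 8>. A small, self-contained sketch of that normalization, using a hypothetical trait in the spirit of scalar_type (not the composable_kernel definition):

#include <cstdint>
#include <cstdio>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

// Hypothetical trait: element type plus lane count of a (possibly vector) data type.
template <typename V> struct toy_scalar_type;
template <> struct toy_scalar_type<int8_t>   { using type = int8_t; static constexpr int vector_size = 1; };
template <> struct toy_scalar_type<int8x4_t> { using type = int8_t; static constexpr int vector_size = 4; };

// Normalize (T, N) -> (scalar element, total lanes), mirroring what amd_buffer_load_v2
// now does before calling amd_buffer_load_impl_v2<scalar_t, vector_size>.
template <typename T, int N>
void report_dispatch()
{
    using scalar_t = typename toy_scalar_type<T>::type;
    constexpr int vector_size = toy_scalar_type<T>::vector_size * N;
    std::printf("dispatch as <scalar of %zu byte(s), %d lanes>\n", sizeof(scalar_t), vector_size);
}

int main()
{
    report_dispatch<int8_t, 8>();   // 8 lanes of int8
    report_dispatch<int8x4_t, 2>(); // also 8 lanes of int8 after normalization
    return 0;
}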
composable_kernel/include/utility/amd_inline_asm.hpp
@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
+   // TODO remove pointer casting
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                                float& c2,
                                                float& c3)
{
+   // TODO remove pointer casting
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);

@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
                                                float& c3)
{
+   // TODO remove pointer casting
    const half4_t* p_a_half4  = reinterpret_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);

@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                                float& c2,
                                                float& c3)
{
+   // TODO remove pointer casting
    const half8_t* p_a_half8  = reinterpret_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
                 v_dot4_i32_i8 %1, %2, %4, %1 \n \
                 "
                 : "=v"(c0), "=v"(c1)
-                : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
+                : "v"(as_type<int32_t>(a)),
+                  "v"(as_type<int32_t>(b0)),
+                  "v"(as_type<int32_t>(b1)),
+                  "0"(c0),
+                  "1"(c1));
#else
-   c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
-   c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
+   c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
+   c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
#endif
}
@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                 v_dot4_i32_i8 %3, %4, %8, %3 \n \
                 "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-                : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
+                : "v"(as_type<int32_t>(a)),
+                  "v"(as_type<int32_t>(b0)),
+                  "v"(as_type<int32_t>(b1)),
+                  "v"(as_type<int32_t>(b2)),
+                  "v"(as_type<int32_t>(b3)),
+                  "0"(c0),
+                  "1"(c1),
+                  "2"(c2),
+                  "3"(c3));
#else
-   c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
-   c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
-   c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
-   c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
+   c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
+   c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
+   c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
+   c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
#endif
}
@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                                int32_t& c2,
                                                int32_t& c3)
{
+   constexpr auto I0 = Number<0>{};
+   constexpr auto I1 = Number<1>{};
-   const int8x4_t* p_a_int8x4_t  = reinterpret_cast<const int8x4_t*>(&a);
-   const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
-   const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
-   const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
-   const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);

-   amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
-                                  p_b0_int8x4_t[0],
-                                  p_b1_int8x4_t[0],
-                                  p_b2_int8x4_t[0],
-                                  p_b3_int8x4_t[0],
+   amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
                                   c0, c1, c2, c3);

-   amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
-                                  p_b0_int8x4_t[1],
-                                  p_b1_int8x4_t[1],
-                                  p_b2_int8x4_t[1],
-                                  p_b3_int8x4_t[1],
+   amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
                                   c0, c1, c2,
@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                                int32_t& c3)
{
+   constexpr auto I0 = Number<0>{};
+   constexpr auto I1 = Number<1>{};
+   constexpr auto I2 = Number<2>{};
+   constexpr auto I3 = Number<3>{};
-   const int8x8_t* p_a_int8x8_t  = reinterpret_cast<const int8x8_t*>(&a);
-   const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
-   const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
-   const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
-   const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);

+   amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
+                                  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
+                                  c0, c1, c2, c3);

+   amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
+                                  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
+                                  c0, c1, c2, c3);

-   amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
-                                  p_b0_int8x8_t[0],
-                                  p_b1_int8x8_t[0],
-                                  p_b2_int8x8_t[0],
-                                  p_b3_int8x8_t[0],
+   amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
+                                  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
+                                  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
+                                  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
+                                  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
+                                  c0, c1, c2, c3);

-   amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
-                                  p_b0_int8x8_t[1],
-                                  p_b1_int8x8_t[1],
-                                  p_b2_int8x8_t[1],
-                                  p_b3_int8x8_t[1],
+   amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
+                                  vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
+                                  vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
+                                  vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
+                                  vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
+                                  c0, c1, c2,
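Once int8x4_t is a real 4-lane vector it can no longer be handed directly to the v_dot4_i32_i8 / __builtin_amdgcn_sdot4 paths, which expect a packed 32-bit operand; the diff above therefore routes every int8x4_t operand through as_type<int32_t>(...). A minimal, self-contained sketch of what such a bit-cast helper does (assumed behaviour, not the composable_kernel definition):

#include <cstdint>
#include <cstring>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));

// Assumed behaviour of as_type<Y>(x): reinterpret the bits of x as Y without changing
// them (sizes must match). Implemented here with memcpy to stay well-defined; the real
// helper may use a union or __builtin_bit_cast instead.
template <typename Y, typename X>
Y toy_as_type(X x)
{
    static_assert(sizeof(X) == sizeof(Y), "as_type requires same-size types");
    Y y;
    std::memcpy(&y, &x, sizeof(Y));
    return y;
}

int32_t pack_for_dot4(int8x4_t a)
{
    // The packed 32-bit value an instruction like v_dot4_i32_i8 consumes.
    return toy_as_type<int32_t>(a);
}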
composable_kernel/include/utility/config.amd.hpp.in

@@ -142,6 +142,11 @@
#define CK_WORKAROUND_SWDEV_275126 1
#endif

+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX
+#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#endif
+
namespace ck {

enum AddressSpace
composable_kernel/include/utility/float_type.amd.hpp.in   (+281 -295)

This diff is collapsed.
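float_type.amd.hpp.in is where the new vector_type_maker, AsType<> and scalar_type machinery used throughout this commit is defined, but its diff is collapsed above. As a reading aid only, here is a hypothetical, heavily simplified sketch of the shape those templates appear to have, inferred from how the other files call them; none of these definitions are taken from the actual file.

#include <cstdint>

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef int8_t int8x8_t __attribute__((ext_vector_type(8)));

// Hypothetical sketch only. Rough shape of the wrapper the rest of the commit relies on:
//   vector_type_maker<T, N>::type       -> an object owning N * sizeof(T) of storage
//   vector_type_maker<T, N>::type::type -> the underlying "real" vector type
//   wrapper.AsType<X>()[i] / (i)        -> view that storage as pieces of type X and index them
struct int8x8_wrapper_sketch // stands in for vector_type_maker<int8_t, 8>::type
{
    using type = int8x8_t; // in the real code this depends on (T, N)

    type data;

    template <typename X>
    X* AsType() { return reinterpret_cast<X*>(&data); } // sketch: real code returns an indexable view
};

int8x4_t low_half(int8x8_wrapper_sketch& w)
{
    // Analogue of w.AsType<int8x4_t>()[Number<0>{}] in the diffs above.
    return w.AsType<int8x4_t>()[0];
}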