Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
a91b68df
Commit
a91b68df
authored
Aug 13, 2021
by
Chao Liu
Browse files
DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
parent
2cbabbba
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
176 additions
and
74 deletions
+176
-74
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+1
-1
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
+4
-2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+3
-2
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+1
-1
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+1
-1
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+64
-37
composable_kernel/include/utility/dynamic_buffer.hpp
composable_kernel/include/utility/dynamic_buffer.hpp
+60
-24
composable_kernel/include/utility/static_buffer.hpp
composable_kernel/include/utility/static_buffer.hpp
+41
-5
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
a91b68df
...
@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static_assert
(
WPerThread
%
WoPerThreadSubC
==
0
,
""
);
static_assert
(
WPerThread
%
WoPerThreadSubC
==
0
,
""
);
// thread A buffer for GEMM
// thread A buffer for GEMM
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
()
>
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
()
,
true
>
a_thread_buf
;
a_thread_buf
;
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
,
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
View file @
a91b68df
...
@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// register allocation for output
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
FloatAcc
,
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
()
>
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
c_thread_buf
;
// initialize output thread tensor
// initialize output thread tensor
...
@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// double regsiter buffer for b
// double regsiter buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAB
,
FloatAB
,
b_e_n_ho_wo_thread_desc
.
GetElementSpaceSize
()
>
b_e_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
true
>
b_thread_even_buf
,
b_thread_odd_buf
;
b_thread_even_buf
,
b_thread_odd_buf
;
// LDS double buffer: preload data
// LDS double buffer: preload data
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
a91b68df
...
@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
vector_type
<
FloatAcc
,
BlkSize
>
,
vector_type
<
FloatAcc
,
BlkSize
>
,
c_mr_nr_blk_desc
.
GetElementSpaceSize
()
>
c_mr_nr_blk_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
c_thread_buf
;
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
...
@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{},
Number<M2>{},
Number<1>{}));
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()
, true
>
c_blk_buf_;
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
a91b68df
...
@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
...
@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
static
constexpr
auto
buffer_size_
=
buffer_desc_
.
GetElementSpaceSize
();
static
constexpr
auto
buffer_size_
=
buffer_desc_
.
GetElementSpaceSize
();
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
SrcData
,
buffer_size_
>
buffer_
;
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
SrcData
,
buffer_size_
,
true
>
buffer_
;
SrcCoord
src_coord_
;
SrcCoord
src_coord_
;
DstCoord
dst_coord_
;
DstCoord
dst_coord_
;
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
a91b68df
...
@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static
constexpr
auto
buffer_size_
=
buffer_desc_
.
GetElementSpaceSize
();
static
constexpr
auto
buffer_size_
=
buffer_desc_
.
GetElementSpaceSize
();
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
SrcData
,
buffer_size_
>
buffer_
;
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
SrcData
,
buffer_size_
,
true
>
buffer_
;
SrcCoord
src_coord_
;
SrcCoord
src_coord_
;
DstCoord
dst_coord_
;
DstCoord
dst_coord_
;
...
...
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
a91b68df
...
@@ -10,25 +10,25 @@ union BufferResource
...
@@ -10,25 +10,25 @@ union BufferResource
{
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t
data
;
int32x4_t
content
;
StaticallyIndexedArray
<
T
*
,
2
>
address
;
StaticallyIndexedArray
<
T
*
,
2
>
address
;
StaticallyIndexedArray
<
int32_t
,
4
>
range
;
StaticallyIndexedArray
<
int32_t
,
4
>
range
;
StaticallyIndexedArray
<
int32_t
,
4
>
config
;
StaticallyIndexedArray
<
int32_t
,
4
>
config
;
};
};
template
<
typename
T
>
template
<
typename
T
>
__device__
int32x4_t
make_wave_buffer_resource
(
T
*
p_wave
,
index_t
data
_space_size
)
__device__
int32x4_t
make_wave_buffer_resource
(
T
*
p_wave
,
index_t
element
_space_size
)
{
{
BufferResource
<
T
>
wave_buffer_resource
;
BufferResource
<
T
>
wave_buffer_resource
;
// wavewise base address (64 bit)
// wavewise base address (64 bit)
wave_buffer_resource
.
address
(
Number
<
0
>
{})
=
const_cast
<
remove_cv_t
<
T
>*>
(
p_wave
);
wave_buffer_resource
.
address
(
Number
<
0
>
{})
=
const_cast
<
remove_cv_t
<
T
>*>
(
p_wave
);
// wavewise range (32 bit)
// wavewise range (32 bit)
wave_buffer_resource
.
range
(
Number
<
2
>
{})
=
data
_space_size
*
sizeof
(
T
);
wave_buffer_resource
.
range
(
Number
<
2
>
{})
=
element
_space_size
*
sizeof
(
T
);
// wavewise setting (32 bit)
// wavewise setting (32 bit)
wave_buffer_resource
.
config
(
Number
<
3
>
{})
=
CK_BUFFER_RESOURCE_3RD_DWORD
;
wave_buffer_resource
.
config
(
Number
<
3
>
{})
=
CK_BUFFER_RESOURCE_3RD_DWORD
;
return
wave_buffer_resource
.
data
;
return
wave_buffer_resource
.
content
;
}
}
// load
// load
...
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
...
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4f32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4f32"
);
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type
<
T
,
N
>::
type
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
amd_buffer_load_impl_v2
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
index_t
src_wave_addr_offset
)
{
{
static_assert
(
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
...
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
}
}
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_store_impl
_v2
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
index_t
dst_wave_addr_offset
)
{
{
static_assert
(
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
...
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
...
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
// buffer_load requires:
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 1) p_src_wave must be in global memory space
// 2) p_src_wave t
o
be a wavewise pointer.
// 2) p_src_wave
mus
t be a wavewise pointer.
// It is user's responsibility to make sure that is true.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type_maker
<
T
,
N
>::
type
::
type
__device__
typename
vector_type_maker
<
T
,
N
>::
type
::
type
amd_buffer_load_
v2
(
const
T
*
p_src_wave
,
amd_buffer_load_
invalid_element_return_return_zero
(
const
T
*
p_src_wave
,
index_t
src_thread_
data
_offset
,
index_t
src_thread_
element
_offset
,
bool
src_thread_
data
_valid
,
bool
src_thread_
element
_valid
,
index_t
src_element_space
)
index_t
src_element_space
_size
)
{
{
const
int32x4_t
src_wave_buffer_resource
=
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space
);
make_wave_buffer_resource
(
p_src_wave
,
src_element_space
_size
);
index_t
src_thread_addr_offset
=
src_thread_data_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_
data
_valid
?
0
:
0x7fffffff
;
uint32_t
src_addr_shift
=
src_thread_
element
_valid
?
0
:
0x7fffffff
;
return
amd_buffer_load_impl
_v2
<
scalar_t
,
vector_size
>
(
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#else
#else
vector_t
tmp
=
amd_buffer_load_impl
_v2
<
scalar_t
,
vector_size
>
(
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_
data
_valid
?
tmp
:
vector_t
(
0
);
return
src_thread_
element
_valid
?
tmp
:
vector_t
(
0
);
#endif
#endif
}
}
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type_maker
<
T
,
N
>::
type
::
type
amd_buffer_load_invalid_element_return_customized_value
(
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
bool
src_thread_element_valid
,
index_t
src_element_space_size
,
T
customized_value
)
{
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
customized_value
);
}
// buffer_store requires:
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
__device__
void
__device__
void
amd_buffer_store
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
amd_buffer_store_v2
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
T
*
p_dst_wave
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
const
index_t
dst_thread_data_offset
,
const
bool
dst_thread_element_valid
,
const
bool
dst_thread_data_valid
,
const
index_t
dst_element_space_size
)
const
index_t
dst_element_space
)
{
{
const
int32x4_t
dst_wave_buffer_resource
=
const
int32x4_t
dst_wave_buffer_resource
=
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space
);
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space
_size
);
index_t
dst_thread_addr_offset
=
dst_thread_
data
_offset
*
sizeof
(
T
);
index_t
dst_thread_addr_offset
=
dst_thread_
element
_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_
data
_valid
?
0
:
0x7fffffff
;
uint32_t
dst_addr_shift
=
dst_thread_
element
_valid
?
0
:
0x7fffffff
;
amd_buffer_store_impl
_v2
<
scalar_t
,
vector_size
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
#else
if
(
dst_thread_
data
_valid
)
if
(
dst_thread_
element
_valid
)
{
{
amd_buffer_store_impl
_v2
<
scalar_t
,
vector_size
>
(
amd_buffer_store_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
}
#endif
#endif
...
...
composable_kernel/include/utility/dynamic_buffer.hpp
View file @
a91b68df
...
@@ -6,34 +6,43 @@
...
@@ -6,34 +6,43 @@
namespace
ck
{
namespace
ck
{
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
,
bool
InvalidElementUseNumericalZeroValue
>
struct
DynamicBuffer
struct
DynamicBuffer
{
{
using
type
=
T
;
using
type
=
T
;
T
*
p_data_
;
T
*
p_data_
;
ElementSpaceSize
element_space_size_
;
ElementSpaceSize
element_space_size_
;
T
invalid_element_value_
=
T
{
0
};
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
{
{
}
}
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
,
T
invalid_element_value
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
},
invalid_element_value_
{
invalid_element_value
}
{
}
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
{
{
return
BufferAddressSpace
;
return
BufferAddressSpace
;
}
}
__host__
__device__
constexpr
const
T
&
operator
[](
index_t
i
)
const
{
return
p_data_
[
i
];
}
__host__
__device__
constexpr
T
&
operator
()(
index_t
i
)
{
return
p_data_
[
i
];
}
template
<
typename
X
,
template
<
typename
X
,
typename
std
::
enable_if
<
typename
std
::
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_
offse
t
)
const
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_
elemen
t
)
const
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
constexpr
index_t
scalar_per_t_vector
=
...
@@ -45,20 +54,41 @@ struct DynamicBuffer
...
@@ -45,20 +54,41 @@ struct DynamicBuffer
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
"wrong! X need to be multiple T"
);
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
bool
constexpr
use_amd_buffer_addressing
=
true
;
return
amd_buffer_load_v2
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_offset
,
element_space_size_
);
#else
#else
return
is_valid_offset
?
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
])
:
X
{
0
}
;
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
#endif
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Global
&&
use_amd_buffer_addressing
)
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
}
else
else
{
{
return
is_valid_offset
?
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
])
:
X
{
0
};
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
is_valid_element
?
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
])
:
X
{
0
};
}
else
{
return
is_valid_element
?
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
])
:
X
{
invalid_element_value_
};
}
}
}
}
}
...
@@ -67,7 +97,7 @@ struct DynamicBuffer
...
@@ -67,7 +97,7 @@ struct DynamicBuffer
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_
offse
t
,
const
X
&
x
)
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_
elemen
t
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
constexpr
index_t
scalar_per_t_vector
=
...
@@ -84,10 +114,10 @@ struct DynamicBuffer
...
@@ -84,10 +114,10 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
_v2
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
amd_buffer_store
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_
offse
t
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_
elemen
t
,
element_space_size_
);
#else
#else
if
(
is_valid_
offse
t
)
if
(
is_valid_
elemen
t
)
{
{
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
}
}
...
@@ -95,7 +125,7 @@ struct DynamicBuffer
...
@@ -95,7 +125,7 @@ struct DynamicBuffer
}
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
)
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
)
{
{
if
(
is_valid_
offse
t
)
if
(
is_valid_
elemen
t
)
{
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
...
@@ -185,7 +215,7 @@ struct DynamicBuffer
...
@@ -185,7 +215,7 @@ struct DynamicBuffer
}
}
else
else
{
{
if
(
is_valid_
offse
t
)
if
(
is_valid_
elemen
t
)
{
{
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
}
}
...
@@ -197,12 +227,18 @@ struct DynamicBuffer
...
@@ -197,12 +227,18 @@ struct DynamicBuffer
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
};
};
template
<
AddressSpaceEnum_t
BufferAddressSpace
=
AddressSpaceEnum_t
::
Generic
,
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
typename
T
,
typename
ElementSpaceSize
>
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
)
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
)
{
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
>
{
p
,
element_space_size
};
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
true
>
{
p
,
element_space_size
};
}
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
T
invalid_element_value
)
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
false
>
{
p
,
element_space_size
,
invalid_element_value
};
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/static_buffer.hpp
View file @
a91b68df
...
@@ -5,30 +5,66 @@
...
@@ -5,30 +5,66 @@
namespace
ck
{
namespace
ck
{
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
>
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
,
bool
InvalidElementUseNumericalZeroValue
>
struct
StaticBuffer
:
public
StaticallyIndexedArray
<
T
,
N
>
struct
StaticBuffer
:
public
StaticallyIndexedArray
<
T
,
N
>
{
{
using
type
=
T
;
using
type
=
T
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
using
base
=
StaticallyIndexedArray
<
T
,
N
>
;
T
invalid_element_value_
=
T
{
0
};
__host__
__device__
constexpr
StaticBuffer
()
:
base
{}
{}
__host__
__device__
constexpr
StaticBuffer
()
:
base
{}
{}
__host__
__device__
constexpr
StaticBuffer
(
T
invalid_element_value
)
:
base
{},
invalid_element_value_
{
invalid_element_value
}
{
}
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
__host__
__device__
static
constexpr
AddressSpaceEnum_t
GetAddressSpace
()
{
{
return
BufferAddressSpace
;
return
BufferAddressSpace
;
}
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
Get
(
Number
<
I
>
i
,
bool
is_valid_element
)
const
{
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
is_valid_element
?
At
(
i
)
:
T
{
0
};
}
else
{
return
is_valid_element
?
At
(
i
)
:
invalid_element_value_
;
}
}
template
<
index_t
I
>
__host__
__device__
void
Set
(
Number
<
I
>
i
,
bool
is_valid_element
,
const
T
&
x
)
{
if
(
is_valid_element
)
{
At
(
i
)
=
x
;
}
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
};
};
template
<
AddressSpaceEnum_t
BufferAddressSpace
=
AddressSpaceEnum_t
::
Generic
,
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
>
typename
T
,
index_t
N
>
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
)
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
)
{
{
return
StaticBuffer
<
BufferAddressSpace
,
T
,
N
>
{};
return
StaticBuffer
<
BufferAddressSpace
,
T
,
N
,
true
>
{};
}
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
index_t
N
>
__host__
__device__
constexpr
auto
make_static_buffer
(
Number
<
N
>
,
T
invalid_element_value
)
{
return
StaticBuffer
<
BufferAddressSpace
,
T
,
N
,
false
>
{
invalid_element_value
};
}
}
}
// namespace ck
}
// namespace ck
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
a91b68df
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
#define USE_MODE 1
#define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC
0
#define USE_CONV_FWD_V4R4R2_NHWC
1
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment