Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c26c154e
Unverified
Commit
c26c154e
authored
Jul 14, 2023
by
rocking
Committed by
GitHub
Jul 14, 2023
Browse files
Merge branch 'develop' into avgpool_bwd
parents
0ab4fa0f
1ee99dca
Changes
155
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1573 additions
and
179 deletions
+1573
-179
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+401
-137
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
...mpl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
...norm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
+704
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
...orm_multiblock/gridwise_multiblock_welford_first_half.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp
..._welford_second_half_batchnorm_forward_final_obsolete.hpp
+9
-10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+6
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+60
-2
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+61
-5
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+47
-10
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+63
-0
include/ck/utility/get_shift.hpp
include/ck/utility/get_shift.hpp
+20
-0
include/ck/utility/reduction_common.hpp
include/ck/utility/reduction_common.hpp
+0
-12
include/ck/utility/workgroup_synchronization.hpp
include/ck/utility/workgroup_synchronization.hpp
+74
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
...ry/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
...wd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
+119
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_
gnwc_gkxc_gnwk_
xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
c26c154e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
c26c154e
File moved
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
c26c154e
File moved
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
c26c154e
File moved
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
View file @
c26c154e
File moved
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
0 → 100644
View file @
c26c154e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
View file @
c26c154e
...
...
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp
→
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp
View file @
c26c154e
...
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m
,
blkgroup_size
,
num_xy_k_block_tile_iteration
,
num_mean_var_count_k_block_tile_iteration
,
epsilon
,
p_in_welford_mean
,
p_in_welford_variance
,
...
...
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
*
1
);
// Step 1: do final welford reduction to get mean and variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf
(
I
)
=
0
;
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_mean_var_count_k_block_tile_iteration
;
++
reducedTiles
)
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_mean_global_val_buf
,
...
...
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf
,
welford_count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_thread_copy_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
c26c154e
...
...
@@ -3,6 +3,8 @@
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
c26c154e
...
...
@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
do
{
#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
);
#endif
block_sync_lds
();
// GEMM i
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
c26c154e
...
...
@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
CK_MIN_WAVES_PER_EU
,
CK_MAX_WAVES_PER_EU
)))
#endif
kernel_gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
CK_MIN_WAVES_PER_EU
,
CK_MAX_WAVES_PER_EU
)))
#endif
kernel_gemm_xdlops_v2r3
(
const
typename
GridwiseGemm
::
Argument
karg
)
{
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
c26c154e
...
...
@@ -29,7 +29,9 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8
,
mfma_i32_32x32x16i8
,
mfma_i32_16x16x32i8
,
mfma_f64_16x16x4f64
mfma_f64_16x16x4f64
,
mfma_f32_32x32x16f8f8
,
mfma_f32_16x16x32f8f8
};
template
<
MfmaInstr
instr
>
...
...
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_32x32x16f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_16x16x32f8f8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
MfmaSelector
{
...
...
@@ -594,6 +640,18 @@ struct MfmaSelector
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
__host__
__device__
constexpr
MfmaSelector
()
...
...
@@ -794,7 +852,7 @@ struct XdlopsGemm
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
c26c154e
...
...
@@ -13,6 +13,61 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
{
template
<
index_t
NDimSpatial
,
typename
ALayout
,
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
>
constexpr
auto
make_out_n_ho_wo_k_grid_desc
(
const
index_t
N
,
const
index_t
Ho
,
const
index_t
Wo
,
const
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_strides
)
{
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
)
{
const
index_t
NStride
=
out_g_n_k_wos_strides
[
1
];
const
index_t
HiStride
=
out_g_n_k_wos_strides
[
3
];
const
index_t
WiStride
=
out_g_n_k_wos_strides
[
4
];
const
auto
CStride
=
Number
<
1
>
{};
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
K
),
make_tuple
(
WiStride
,
CStride
));
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Ho
,
Wo
,
K
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
CStride
));
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
)
{
// assume packed
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
}
else
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
}
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
ALayout
::
name
());
}
}
}
// namespace
template
<
index_t
NDimSpatial
,
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
,
...
...
@@ -29,11 +84,12 @@ struct TransformConvBwdDataToGemm_v1
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>,
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_AK0_M_AK1
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*
out_g_n_k_wos_strides
*/
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
...
...
@@ -70,9 +126,9 @@ struct TransformConvBwdDataToGemm_v1
const
index_t
AK0
=
K
/
AK1
;
// assume packed
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
));
make_out_n_ho_wo_k_grid_desc
<
NDimSpatial
,
ALayout
,
ConvBwdDataSpecialization
>
(
N
,
Ho
,
Wo
,
K
,
out_g_n_k_wos_strides
);
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
...
...
@@ -80,7 +136,7 @@ struct TransformConvBwdDataToGemm_v1
{
// A: output tensor
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
))
,
out_n_ho_wo_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
c26c154e
...
...
@@ -1114,13 +1114,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
}
#else
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
}
#endif
}
...
...
@@ -1179,13 +1196,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
#else
if
(
dst_thread_element_valid
)
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
}
#endif
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
c26c154e
...
...
@@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
template
<
>
struct
intrin_mfma_f32_32x32x16f8f8
<
32
,
32
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
f8x8_t
&
reg_a
,
const
f8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
reg_a
),
bit_cast
<
long
>
(
reg_b
),
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_b_f32
=
type_convert
<
float
>
(
reg_b_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
intrin_mfma_f32_32x32x2f32
<
32
,
32
>::
Run
(
reg_a_f32
,
reg_b_f32
,
reg_c
);
});
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x32f8f8
;
template
<
>
struct
intrin_mfma_f32_16x16x32f8f8
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
f8x8_t
&
reg_a
,
const
f8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8
(
bit_cast
<
long
>
(
reg_a
),
bit_cast
<
long
>
(
reg_b
),
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_b_f32
=
type_convert
<
float
>
(
reg_b_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
intrin_mfma_f32_16x16x4f32
<
16
,
16
>::
Run
(
reg_a_f32
,
reg_b_f32
,
reg_c
);
});
#endif
}
};
}
// namespace ck
#endif
include/ck/utility/get_shift.hpp
0 → 100644
View file @
c26c154e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
template
<
index_t
N
>
static
constexpr
__device__
index_t
get_shift
()
{
return
(
get_shift
<
N
/
2
>
()
+
1
);
};
template
<
>
constexpr
__device__
index_t
get_shift
<
1
>
()
{
return
(
0
);
}
}
// namespace ck
include/ck/utility/reduction_common.hpp
View file @
c26c154e
...
...
@@ -25,16 +25,4 @@ struct float_equal_zero
};
};
template
<
index_t
N
>
static
constexpr
__device__
index_t
get_shift
()
{
return
(
get_shift
<
N
/
2
>
()
+
1
);
};
template
<
>
constexpr
__device__
index_t
get_shift
<
1
>
()
{
return
(
0
);
}
}
// namespace ck
include/ck/utility/workgroup_synchronization.hpp
0 → 100644
View file @
c26c154e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/host_utility/hip_check_error.hpp"
namespace
ck
{
// Initialization flag of Barrier object, can be any value except for zero
static
constexpr
int
BarrierInitFlag
=
0x7856
;
// 1) only the first thread-block in the synchronizaton group is supposed to call this function. It
// is the responsibility of the user to ensure the two integer values in p_control_bits are zeros
// before calling gms_init().
// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no
// repetitious initialization of p_control_bits buffer is required
static
__device__
void
gms_init
(
int
NumWarps
,
int
*
p_control_bits
)
{
union
{
int
two32
[
2
];
unsigned
long
one64
;
}
regs
;
regs
.
two32
[
0
]
=
BarrierInitFlag
;
regs
.
two32
[
1
]
=
NumWarps
;
if
(
threadIdx
.
x
==
0
)
atomicCAS
(
reinterpret_cast
<
unsigned
long
*>
(
p_control_bits
),
0
,
regs
.
one64
);
};
// all the workgroups in the synchronization group is supposed to call this function
static
__device__
void
gms_barrier
(
int
*
p_control_bits
)
{
constexpr
int
mask
=
warpSize
-
1
;
if
((
threadIdx
.
x
&
mask
)
==
0
)
{
// ensure the barrier object is initialized
do
{
const
int
r0
=
__atomic_load_n
(
&
p_control_bits
[
0
],
__ATOMIC_RELAXED
);
if
(
r0
==
BarrierInitFlag
)
break
;
}
while
(
true
);
// go ahead toward the barrier line
atomicSub
(
&
p_control_bits
[
1
],
1
);
// wait until all warps have arrived
do
{
const
int
r1
=
__atomic_load_n
(
&
p_control_bits
[
1
],
__ATOMIC_RELAXED
);
if
(
r1
==
0
)
break
;
}
while
(
true
);
};
};
// 1) Only the first thread-block in the synchronizaton group is supposed to call this function.
// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no
// repetitious initialization of p_control_bits buffer is required
static
__device__
void
gms_reset
(
int
*
p_control_bits
)
{
// reset the barrier object
if
(
threadIdx
.
x
==
0
)
(
void
)
atomicCAS
(
&
p_control_bits
[
0
],
BarrierInitFlag
,
0
);
};
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
View file @
c26c154e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
0 → 100644
View file @
c26c154e
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment