Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b11569f
Commit
0b11569f
authored
Jul 01, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute
parents
e8d3a0fb
fa9a0a5c
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
325 additions
and
244 deletions
+325
-244
include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp
...or_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+102
-99
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...de/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp
...ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp
...e/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp
...e/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+87
-84
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
...k/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
+85
-61
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
...nsor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
+3
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
@@ -20,19 +23,19 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC1
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
D
GridDescriptor_MBlock_MPerBlock
,
typename
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -43,15 +46,15 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_
c0
_grid
,
const
FloatC1
*
__restrict__
p_
c1
_grid
,
D
PtrsGlobal
p_
d
s_grid
,
const
FloatC0
*
__restrict__
p_
bias
_grid
,
const
FloatC1
*
__restrict__
p_
d0
_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C1ElementwiseOperation
c1_element_op
,
const
Dxs
InElementwiseOperation
dxs
_in_element_op
,
const
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
,
const
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
const
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -60,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
D
GridDescriptor_MBlock_MPerBlock
d
_grid_desc_mblock_mperblock
,
const
Reduce
GridDescriptor_MBlock_MPerBlock
reduce
_grid_desc_mblock_mperblock
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -69,42 +72,42 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_
c0
_grid
,
p_
c1
_grid
,
p_
d
s_grid
,
p_
bias
_grid
,
p_
d0
_grid
,
p_
reduce
s_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
c1_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
d
_grid_desc_mblock_mperblock
,
reduce
_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_
c0
_grid
;
ignore
=
p_
c1
_grid
;
ignore
=
p_
d
s_grid
;
ignore
=
p_
bias
_grid
;
ignore
=
p_
d0
_grid
;
ignore
=
p_
reduce
s_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c1_element_op
;
ignore
=
dxs
_in_element_op
;
ignore
=
dxs
_out_element_op
;
ignore
=
reduce
_in_element_op
s
;
ignore
=
reduce
_out_element_op
s
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c0_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c1_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d
_grid_desc_mblock_mperblock
;
ignore
=
reduce
_grid_desc_mblock_mperblock
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -116,22 +119,22 @@ template <typename FloatAB,
typename
FloatC0
,
typename
FloatC1
,
typename
FloatReduceAcc
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C1ElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
Reduce
GlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
D
GridDesc_M
,
typename
Reduce
GridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -318,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__
__device__
static
constexpr
auto
Make
D
GridDescriptor_MBlock_MPerBlock
(
const
D
GridDesc_M
&
d_grid_desc_m
)
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
const
Reduce
GridDesc_M
&
d_grid_desc_m
)
{
const
auto
M
=
d_grid_desc_m
.
GetLength
(
I0
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
d
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
const
auto
reduce
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
d_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
return
d
_grid_desc_mblock_mperblock
;
return
reduce
_grid_desc_mblock_mperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -349,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C1GridDesc_M_N
{}))
>
;
using
D
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
D
GridDescriptor_MBlock_MPerBlock
(
D
GridDesc_M
{}))
>
;
using
Reduce
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
Reduce
GridDesc_M
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
DPtrsGlobal
p_ds_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsReduceAccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_MBlock_MPerBlock
&
d_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_bias_grid
,
const
FloatC1
*
__restrict__
p_d0_grid
,
ReducePtrsGlobal
p_reduces_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C1ElementwiseOperation
&
c1_element_op
,
const
ReduceInElementwiseOperations
&
reduce_in_element_ops
,
const
ReduceAccElementwiseOperations
&
reduce_out_element_ops
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
const
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c1_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ReduceGridDescriptor_MBlock_MPerBlock
&
reduce_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -387,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
c0
_grid
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_
bias
_grid
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
c1
_grid
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_
d0
_grid
,
c1_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
...
...
@@ -722,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{},
Number
<
nreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mperblock
constexpr
auto
d_
reduce_thread_desc_mperblock
=
// VGPR reduce_thread_desc_mperblock
constexpr
auto
reduce_thread_desc_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mblock_mperblock
constexpr
auto
d_
reduce_thread_desc_mblock_mperblock
=
// VGPR reduce_thread_desc_mblock_mperblock
constexpr
auto
reduce_thread_desc_mblock_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{}));
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
...
...
@@ -756,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1
,
true
>
{
c_reduce_block_desc_mperblock_nperblock
,
c_reduce_thread_data_idx_begin
};
auto
dxs_
reduce_thread_copy_vgpr_to_global
=
generate_tuple
(
auto
reduce_
tuple_
thread_copy_vgpr_to_global
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_
d
_grid
=
p_
d
s_grid
[
I
];
auto
d_out
_element_op
=
dxs
_out_element_op
[
I
];
auto
p_
reduce
_grid
=
p_
reduce
s_grid
[
I
];
auto
reduce_acc
_element_op
=
reduce
_out_element_op
s
[
I
];
return
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
remove_pointer_t
<
decltype
(
p_
d
_grid
)
>
,
decltype
(
d_
reduce_thread_desc_mblock_mperblock
),
decltype
(
d
_grid_desc_mblock_mperblock
),
decltype
(
d_out
_element_op
),
remove_pointer_t
<
decltype
(
p_
reduce
_grid
)
>
,
decltype
(
reduce_thread_desc_mblock_mperblock
),
decltype
(
reduce
_grid_desc_mblock_mperblock
),
decltype
(
reduce_acc
_element_op
),
Sequence
<
1
,
mreduce_per_thread
>
,
Sequence
<
0
,
1
>
,
1
,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
D
GlobalMemoryDataOperation
::
At
(
I
),
Reduce
GlobalMemoryDataOperation
::
At
(
I
),
1
,
false
>
{
d
_grid_desc_mblock_mperblock
,
false
>
{
reduce
_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
c_reduce_thread_data_idx_begin
[
I0
]),
// mperblock
d_out
_element_op
};
reduce_acc
_element_op
};
},
Number
<
p_
d
s_grid
.
Size
()
>
{});
Number
<
p_
reduce
s_grid
.
Size
()
>
{});
// c0 and c1
constexpr
auto
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -906,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
static_for
<
0
,
p_
d
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
d
_grid
=
p_
d
s_grid
[
In
];
static_for
<
0
,
p_
reduce
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
reduce
_grid
=
p_
reduce
s_grid
[
In
];
auto
d
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
d
_grid
,
d
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
reduce
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
reduce
_grid
,
reduce
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
d
_thread_buf
=
auto
reduce
_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
&
d
_in_element_op
=
dxs
_in_element_op
[
In
];
auto
&
reduce
_in_element_op
=
reduce
_in_element_op
s
[
In
];
auto
&
d_
reduce_thread_copy_vgpr_to_global
=
dxs_
reduce_thread_copy_vgpr_to_global
(
In
);
auto
&
reduce_thread_copy_vgpr_to_global
=
reduce_
tuple_
thread_copy_vgpr_to_global
(
In
);
using
D
ReduceOperation
=
remove_cvref_t
<
decltype
(
Dxs
ReduceOperation
{}[
In
])
>
;
using
ReduceOperation
=
remove_cvref_t
<
decltype
(
ReduceOperation
s
{}[
In
])
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_
reduce_thread_desc_mperblock
),
D
ReduceOperation
,
decltype
(
reduce_thread_desc_mperblock
),
ReduceOperation
,
false
>
;
// Global write Gemm shuffle + reduction
const
auto
d_zero
Val
=
D
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
const
auto
reduce_identity
Val
=
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d
_thread_buf
(
I
)
=
d_zero
Val
;
});
[
&
](
auto
I
)
{
reduce
_thread_buf
(
I
)
=
reduce_identity
Val
;
});
// reduce in VGPR
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
...
...
@@ -943,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
d
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
reduce
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
});
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
d
_thread_buf
);
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
reduce
_thread_buf
);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d_thread_buf
,
d_grid_desc_mblock_mperblock
,
d_grid_buf
);
reduce_thread_copy_vgpr_to_global
.
Run
(
reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
reduce_thread_buf
,
reduce_grid_desc_mblock_mperblock
,
reduce_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
d_
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d
_grid_desc_mblock_mperblock
,
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
reduce
_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
}
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_GEMM_V2_HPP
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V3_HPP
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
@@ -18,16 +21,16 @@ namespace ck {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
D
GridDescriptor_MBlock_MPerBlock
,
typename
Reduce
GridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -38,17 +41,17 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
D
PtrsGlobal
p_
d
s_grid
,
Reduce
PtrsGlobal
p_
reduce
s_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Dxs
InElementwiseOperation
dxs
_in_element_op
,
const
Dxs
ReduceAccElementwiseOperation
dxs
_out_element_op
,
const
Reduce
InElementwiseOperation
s
reduce
_in_element_op
s
,
const
ReduceAccElementwiseOperation
s
reduce
_out_element_op
s
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
D
GridDescriptor_MBlock_MPerBlock
d
_grid_desc_mblock_mperblock
,
const
Reduce
GridDescriptor_MBlock_MPerBlock
reduce
_grid_desc_mblock_mperblock
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -57,32 +60,32 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_
d
s_grid
,
p_
reduce
s_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs
_in_element_op
,
dxs
_out_element_op
,
reduce
_in_element_op
s
,
reduce
_out_element_op
s
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d
_grid_desc_mblock_mperblock
,
reduce
_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_
d
s_grid
;
ignore
=
p_
reduce
s_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
dxs
_in_element_op
;
ignore
=
dxs
_out_element_op
;
ignore
=
reduce
_in_element_op
s
;
ignore
=
reduce
_out_element_op
s
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
d
_grid_desc_mblock_mperblock
;
ignore
=
reduce
_grid_desc_mblock_mperblock
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -92,19 +95,19 @@ template <typename FloatAB,
typename
FloatCShuffle
,
typename
FloatC
,
typename
FloatReduceAcc
,
typename
D
PtrsGlobal
,
typename
Reduce
PtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Dxs
ReduceOperation
,
typename
Dxs
InElementwiseOperation
,
typename
Dxs
ReduceAccElementwiseOperation
,
typename
ReduceOperation
s
,
typename
Reduce
InElementwiseOperation
s
,
typename
ReduceAccElementwiseOperation
s
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
D
GlobalMemoryDataOperation
,
typename
Reduce
GlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
D
GridDesc_M
,
typename
Reduce
GridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -290,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__
__device__
static
constexpr
auto
Make
D
GridDescriptor_MBlock_MPerBlock
(
const
D
GridDesc_M
&
d_grid_desc_m
)
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
const
Reduce
GridDesc_M
&
d_grid_desc_m
)
{
const
auto
M
=
d_grid_desc_m
.
GetLength
(
I0
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
d
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
const
auto
reduce
_grid_desc_mblock_mperblock
=
transform_tensor_descriptor
(
d_grid_desc_m
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{}));
return
d
_grid_desc_mblock_mperblock
;
return
reduce
_grid_desc_mblock_mperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -315,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
D
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
D
GridDescriptor_MBlock_MPerBlock
(
D
GridDesc_M
{}))
>
;
using
Reduce
GridDescriptor_MBlock_MPerBlock
=
remove_cvref_t
<
decltype
(
Make
Reduce
GridDescriptor_MBlock_MPerBlock
(
Reduce
GridDesc_M
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
DPtrsGlobal
p_ds_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsReduceAccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
DGridDescriptor_MBlock_MPerBlock
&
d_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
ReducePtrsGlobal
p_reduces_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
ReduceInElementwiseOperations
&
reduce_in_element_ops
,
const
ReduceAccElementwiseOperations
&
reduce_out_element_ops
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ReduceGridDescriptor_MBlock_MPerBlock
&
reduce_grid_desc_mblock_mperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -703,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{},
Number
<
nreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mperblock
constexpr
auto
d_
reduce_thread_desc_mperblock
=
// VGPR reduce_thread_desc_mperblock
constexpr
auto
reduce_thread_desc_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
mreduce_per_thread
>
{}));
// VGPR
d_
reduce_thread_desc_mblock_mperblock
constexpr
auto
d_
reduce_thread_desc_mblock_mperblock
=
// VGPR reduce_thread_desc_mblock_mperblock
constexpr
auto
reduce_thread_desc_mblock_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{}));
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
...
...
@@ -737,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1
,
true
>
{
c_reduce_block_desc_mperblock_nperblock
,
c_reduce_thread_data_idx_begin
};
auto
dxs_
reduce_thread_copy_vgpr_to_global
=
generate_tuple
(
auto
reduce_
tuple_
thread_copy_vgpr_to_global
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_
d
_grid
=
p_
d
s_grid
[
I
];
auto
d_out
_element_op
=
dxs
_out_element_op
[
I
];
auto
p_
reduce
_grid
=
p_
reduce
s_grid
[
I
];
auto
reduce_acc
_element_op
=
reduce
_out_element_op
s
[
I
];
return
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
remove_pointer_t
<
decltype
(
p_
d
_grid
)
>
,
decltype
(
d_
reduce_thread_desc_mblock_mperblock
),
decltype
(
d
_grid_desc_mblock_mperblock
),
decltype
(
d_out
_element_op
),
remove_pointer_t
<
decltype
(
p_
reduce
_grid
)
>
,
decltype
(
reduce_thread_desc_mblock_mperblock
),
decltype
(
reduce
_grid_desc_mblock_mperblock
),
decltype
(
reduce_acc
_element_op
),
Sequence
<
1
,
mreduce_per_thread
>
,
Sequence
<
0
,
1
>
,
1
,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
D
GlobalMemoryDataOperation
::
At
(
I
),
Reduce
GlobalMemoryDataOperation
::
At
(
I
),
1
,
false
>
{
d
_grid_desc_mblock_mperblock
,
false
>
{
reduce
_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
c_reduce_thread_data_idx_begin
[
I0
]),
// mperblock
d_out
_element_op
};
reduce_acc
_element_op
};
},
Number
<
p_
d
s_grid
.
Size
()
>
{});
Number
<
p_
reduce
s_grid
.
Size
()
>
{});
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
...
...
@@ -794,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple
(
I0
,
I0
),
c_reduce_thread_buf
);
static_for
<
0
,
p_
d
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
d
_grid
=
p_
d
s_grid
[
In
];
static_for
<
0
,
p_
reduce
s_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_
reduce
_grid
=
p_
reduce
s_grid
[
In
];
auto
d
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
d
_grid
,
d
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
reduce
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
reduce
_grid
,
reduce
_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
d
_thread_buf
=
auto
reduce
_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
&
d
_in_element_op
=
dxs
_in_element_op
[
In
];
auto
&
reduce
_in_element_op
=
reduce
_in_element_op
s
[
In
];
auto
&
d_
reduce_thread_copy_vgpr_to_global
=
dxs_
reduce_thread_copy_vgpr_to_global
(
In
);
auto
&
reduce_thread_copy_vgpr_to_global
=
reduce_
tuple_
thread_copy_vgpr_to_global
(
In
);
using
D
ReduceOperation
=
remove_cvref_t
<
decltype
(
Dxs
ReduceOperation
{}[
In
])
>
;
using
ReduceOperation
=
remove_cvref_t
<
decltype
(
ReduceOperation
s
{}[
In
])
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_
reduce_thread_desc_mperblock
),
D
ReduceOperation
,
decltype
(
reduce_thread_desc_mperblock
),
ReduceOperation
,
false
>
;
// Global write Gemm shuffle + reduction
const
auto
d
_identityVal
=
D
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
const
auto
reduce
_identityVal
=
ReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d
_thread_buf
(
I
)
=
d
_identityVal
;
});
[
&
](
auto
I
)
{
reduce
_thread_buf
(
I
)
=
reduce
_identityVal
;
});
// reduce in VGPR
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
...
...
@@ -831,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
d
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
reduce
_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
});
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
d
_thread_buf
);
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
reduce
_thread_buf
);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d_thread_buf
,
d_grid_desc_mblock_mperblock
,
d_grid_buf
);
reduce_thread_copy_vgpr_to_global
.
Run
(
reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
reduce_thread_buf
,
reduce_grid_desc_mblock_mperblock
,
reduce_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
d_
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d
_grid_desc_mblock_mperblock
,
reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
reduce
_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
}
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
...
...
@@ -46,7 +49,8 @@ template <typename InDataType,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
index_t
OutDstVectorSize
,
bool
SweepOnce
>
struct
GridwiseSoftmax_mk_to_mk
{
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
...
...
@@ -72,19 +76,6 @@ struct GridwiseSoftmax_mk_to_mk
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
>
;
// PropagateNan
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
>
;
// PropagateNan
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -102,6 +93,11 @@ struct GridwiseSoftmax_mk_to_mk
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
)
{
if
constexpr
(
SweepOnce
)
{
num_k_block_tile_iteration
=
1
;
}
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
...
...
@@ -146,6 +142,20 @@ struct GridwiseSoftmax_mk_to_mk
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
// another value_max. As numbers become non-zero, effectively it allows invalid values to
// slip through and contribute to the accumulated result.
//
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
// be identified as an invalid value. We can then discard the invalid values which
// originally failed the bound check during accumulation. This allows to ignore values that
// failed bound check even after multiple math manipulations.
//
// NOTE: reset coordinate after every step because the same threadwise copy will sweep
// through global memory 3 times back and forth
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
GridDesc_M_K
,
...
...
@@ -155,7 +165,8 @@ struct GridwiseSoftmax_mk_to_mk
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
true
/* ResetCoordAfterRun */
,
true
/* InvalidElementAsNaN */
>
(
in_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
...
...
@@ -195,21 +206,39 @@ struct GridwiseSoftmax_mk_to_mk
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
in_thread_copy_fwd_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
constexpr
auto
in_thread_copy_bwd_step
=
make_multi_index
(
0
,
-
K_BlockTileSize
);
constexpr
auto
in_thread_copy_fwd_step
=
make_multi_index
(
0
,
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
in_thread_copy_bwd_step
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
///
/// max(x)
///
const
auto
in_global_val_buf_oob_non_zero
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
reduce
::
Max
::
template
GetIdentityValue
<
InDataType
>());
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
const
auto
in_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
());
index_t
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
_oob_non_zero
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
...
...
@@ -229,26 +258,6 @@ struct GridwiseSoftmax_mk_to_mk
///
/// sum(exp(x - max(x)))
///
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
// another value_max. As numbers become non-zero, effectively it allows invalid values to
// slip through and contribute to the accumulated result.
//
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
// be identified as an invalid value. We can then discard the invalid values which
// originally failed the bound check during accumulation. This allows to ignore values that
// failed bound check even after multiple math manipulations.
const
auto
in_global_val_buf_oob_nan
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
NumericLimits
<
InDataType
>::
QuietNaN
());
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
...
...
@@ -269,22 +278,25 @@ struct GridwiseSoftmax_mk_to_mk
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
if
constexpr
(
!
SweepOnce
)
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
}
// do element-wise pre-reduction operation
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in
_thread_buf
(
Number
<
offset
>
{})
=
out
_thread_buf
(
Number
<
offset
>
{})
=
math
::
exp
(
in_thread_buf
(
Number
<
offset
>
{})
-
max_value_buf
(
iM
));
});
});
ThreadwiseSumReduce
::
Reduce
(
in
_thread_buf
,
accu_value_buf
);
ThreadwiseSumReduce
::
Reduce
(
out
_thread_buf
,
accu_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_bwd_step
);
...
...
@@ -306,11 +318,14 @@ struct GridwiseSoftmax_mk_to_mk
{
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
if
constexpr
(
!
SweepOnce
)
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
...
...
@@ -337,18 +352,27 @@ struct GridwiseSoftmax_mk_to_mk
}
else
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_prior_dst_buf
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
if
constexpr
(
!
SweepOnce
)
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
}
threadwise_dst_load
.
Run
(
out_grid_desc_m_k
,
out_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
out_thread_buf
);
in_prior_dst_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
...
...
@@ -357,7 +381,7 @@ struct GridwiseSoftmax_mk_to_mk
out_thread_buf
(
Number
<
offset
>
{})
=
alpha
*
math
::
exp
(
in_thread_buf
(
Number
<
offset
>
{})
-
max_value_buf
(
iM
))
/
accu_value_buf
(
iM
)
+
beta
*
out_thread
_buf
(
Number
<
offset
>
{});
beta
*
in_prior_dst
_buf
(
Number
<
offset
>
{});
});
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment