gaoqiong / composable_kernel · Commit 15f1d4ad
authored Dec 13, 2022 by Anthony Chang
parent a3e487ca

compute y dot dy

Showing 4 changed files with 331 additions and 52 deletions (+331, -52)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+2, -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  (+82, -36)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+233, -13)
include/ck/utility/dynamic_buffer.hpp  (+14, -2)
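Note on the change itself: the "y dot dy" in the commit title is the per-row inner product between the forward attention output Y and its incoming gradient dY. The softmax backward in this kernel needs it for dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) (see the comment carried in the device file below). As a reference for what the new gridwise code accumulates, here is a minimal host-side sketch; it is not part of the commit, and the names and row-major M x O layout are assumptions for illustration:

    #include <vector>

    // y_dot_ygrad[m] = sum over o of Y[m][o] * dY[m][o]
    std::vector<float> ReferenceYDotYGrad(const std::vector<float>& y,     // M x O, row-major
                                          const std::vector<float>& ygrad, // M x O, row-major
                                          int M,
                                          int O)
    {
        std::vector<float> y_dot_ygrad(M, 0.f);
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
                y_dot_ygrad[m] += y[m * O + o] * ygrad[m * O + o];
        return y_dot_ygrad;
    }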
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -369,7 +369,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
+        // ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
     }

     // calculate y & log-sum-exp beforehand
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
-          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
           typename LSEGridDescriptor_M,
           typename VGradGridDescriptor_N_O,
           typename YGradGridDesc_M0_O_M1,
@@ -65,7 +65,7 @@ __global__ void
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
-           const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+           const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const LSEGridDescriptor_M lse_grid_desc_m,
            // const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
@@ -329,6 +329,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
        // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
        // transformation overhead
+       // TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+       // extract subsequence and shuffle them.
        const index_t num_dims = NumDimG + NumDimN + NumDimO;

        // 0, 1, .. NumDimG - 1
@@ -372,31 +374,21 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                     Sequence<padder.PadN, padder.PadO>{});
    }

-   template <typename YGridDesc_M_O, typename Number>
-   static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o,
-                                               const Number& M1)
+   template <typename YGridDesc_M_O>
+   static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o)
    {
        const auto M = ygrad_grid_desc_m_o.GetLength(I0);
        const auto O = ygrad_grid_desc_m_o.GetLength(I1);

-       const auto M0 = M / M1;
+       const auto Y_M0 = M / Y_M1;

        return transform_tensor_descriptor(
            ygrad_grid_desc_m_o,
-           make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
+           make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
                       make_pass_through_transform(O)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-   // can we construct YGrad_m0_o_m1 from Y_m_o?
-   // static auto MakeYGradGridDescriptor_M0_O_M1(const std::vector<index_t>&
-   // y_gs_ms_os_lengths_vec,
-   //                                             const std::vector<index_t>&
-   //                                             y_gs_ms_os_strides_vec)
-   // {
-   // }
-
    //
    // dP = dY * V^T
    //
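The new overload above derives the row split from the class-level constant Y_M1 rather than taking M1 as a parameter. The unmerge sends row index m to (m0, m1) = (m / Y_M1, m % Y_M1) and passes o through, so the descriptor views the M x O ygrad matrix as M0 x O x M1. A stand-alone sketch of the same index arithmetic (illustrative only, not CK code; y_m1 stands in for Y_M1):

    // Maps an (m, o) coordinate of the ygrad M x O view to the (m0, o, m1)
    // coordinate produced by the unmerge / pass-through transforms above.
    struct YGradIdx
    {
        int m0, o, m1;
    };

    inline YGradIdx MapYGradIndex(int m, int o, int y_m1)
    {
        return YGradIdx{m / y_m1, o, m % y_m1};
    }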
@@ -410,6 +402,61 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                                          Number<Y_O1>{});
    }

+   // V in Gemm B position
+   static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
+                                           const std::vector<index_t>& v_gs_os_ns_strides_vec)
+   {
+       // v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
+       // Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
+       // transformation overhead
+       // TODO ANT: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
+       // extract subsequence and shuffle them.
+       const index_t num_dims = NumDimG + NumDimN + NumDimO;
+
+       // 0, 1, .. NumDimG - 1
+       std::vector<index_t> gs_ids(NumDimG);
+       std::iota(gs_ids.begin(), gs_ids.end(), 0);
+
+       // NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
+       std::vector<index_t> os_ids(NumDimO);
+       std::iota(os_ids.begin(), os_ids.end(), NumDimG);
+
+       // NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
+       std::vector<index_t> ns_ids(NumDimN);
+       std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
+
+       std::vector<index_t> ids_old2new;
+       ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
+       ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
+       ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
+
+       std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims);
+       for(int i = 0; i < num_dims; i++)
+       {
+           index_t id_new            = ids_old2new[i];
+           v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new];
+           v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new];
+       }
+
+       const auto v_grid_desc_nraw_oraw =
+           MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
+               v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
+               .second;
+
+       const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
+                                                        make_tuple(NPerBlock, Gemm1NPerBlock),
+                                                        Sequence<padder.PadN, padder.PadO>{});
+
+       // N_O to O0_N_O1; to refactor
+       return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
+   }
+
    //
    // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
    //
    // static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()

    //
    // dQ = alpha * dS * K
    //
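The index shuffling above only permutes the lengths/strides vectors from the v_gs_os_ns ordering into the v_gs_ns_os ordering the GEMM-B descriptor expects: G dims first, then N dims, then O dims. A small stand-alone illustration of the resulting permutation (hypothetical dimension counts, not part of the commit):

    #include <numeric>
    #include <vector>

    // For num_g = 2, num_o = 1, num_n = 1 this returns {0, 1, 3, 2}:
    // the G dims keep their slots and the N dim moves ahead of the O dim.
    std::vector<int> MakeOld2NewIds(int num_g, int num_o, int num_n)
    {
        std::vector<int> gs(num_g), os(num_o), ns(num_n), ids;
        std::iota(gs.begin(), gs.end(), 0);
        std::iota(os.begin(), os.end(), num_g);
        std::iota(ns.begin(), ns.end(), num_g + num_o);
        ids.insert(ids.end(), gs.begin(), gs.end());
        ids.insert(ids.end(), ns.begin(), ns.end());
        ids.insert(ids.end(), os.begin(), os.end());
        return ids;
    }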
@@ -460,7 +507,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
-   using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+   using YGridDesc_M_O        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
@@ -468,8 +515,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
    using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

    using VGradGridDesc_N_O     = decltype(MakeVGradGridDescriptor_N_O({}, {}));
-   using YGradGridDesc_M0_O_M1 =
-       decltype(MakeYGradGridDescriptor_M0_O_M1(CGridDesc_M_N{}, Number<Y_M1>{}));
+   using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));

    constexpr static auto make_MaskOutPredicate()
    {
@@ -547,7 +593,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        B1GridDesc_BK0_N_BK1,
-       CGridDesc_M_N,
+       YGridDesc_M_O,
        LSEGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
@@ -647,7 +693,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
-             c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
+             y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
              // dV = P^T * dY
@@ -655,7 +701,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              /* PTrans descriptor will be constructed in kernel */
-             ygrad_grid_desc_m0_o_m1_{
-                 DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
+             ygrad_grid_desc_m0_o_m1_{
+                 DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
              // batch offsets
              a_grid_desc_g_m_k_{Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths,
                                                                      a_gs_ms_ks_strides)},
@@ -665,8 +711,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
-             c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-             block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
+             y_grid_desc_mblock_mperblock_oblock_operblock_{},
+             block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
@@ -704,12 +750,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
-                                          c_grid_desc_m_n_,
+                                          y_grid_desc_m_o_,
                                           block_2_ctile_map_))
            {
-               c_grid_desc_mblock_mperblock_nblock_nperblock_ =
+               y_grid_desc_mblock_mperblock_oblock_operblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                       c_grid_desc_m_n_);
+                       y_grid_desc_m_o_);
            }

            Print();
        }
@@ -754,14 +800,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
-       CGridDesc_M_N c_grid_desc_m_n_;
+       YGridDesc_M_O y_grid_desc_m_o_;
        LSEGridDesc_M lse_grid_desc_m_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
-       typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-           c_grid_desc_mblock_mperblock_nblock_nperblock_;
+       typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
+           y_grid_desc_mblock_mperblock_oblock_operblock_;
        VGradGridDesc_N_O vgrad_grid_desc_n_o_;
        YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
@@ -803,7 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }

            const index_t grid_size =
-               arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+               arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;

            std::cout << "grid size = " << grid_size << '\n';

            // Gemm0_K
            const auto K =
@@ -824,7 +870,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                DeviceOp::AGridDesc_AK0_M_AK1,
                DeviceOp::BGridDesc_BK0_N_BK1,
                DeviceOp::B1GridDesc_BK0_N_BK1,
-               typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+               typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
                DeviceOp::LSEGridDesc_M,
                DeviceOp::VGradGridDesc_N_O,
                DeviceOp::YGradGridDesc_M0_O_M1,
@@ -855,7 +901,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.b1_grid_desc_bk0_n_bk1_,
-               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+               arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
                arg.lse_grid_desc_m_,
                arg.vgrad_grid_desc_n_o_,
                arg.ygrad_grid_desc_m0_o_m1_,
@@ -909,8 +955,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
-       const index_t c_m      = arg.c_grid_desc_m_n_.GetLength(I0);
-       const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
+       const index_t c_m      = arg.y_grid_desc_m_o_.GetLength(I0);
+       const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
        const index_t a_m        = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n  = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
@@ -960,7 +1006,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
-                                          arg.c_grid_desc_m_n_,
+                                          arg.y_grid_desc_m_o_,
                                           arg.block_2_ctile_map_);
    }
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -211,6 +211,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later

    // PGrad Gemm
    struct PGradGemmTile_M_N_O_
    {
    };

+   template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
+   struct YDotYGrad_M_O_
+   {
+       static constexpr index_t SrcScalarPerVetor  = 16 / sizeof(DataType);
+       static constexpr auto ThreadClusterLength_O =
+           Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
+       static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
+       static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVetor>{};
+       static constexpr auto ThreadSliceLength_M =
+           Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
+
+       // static_assert(BlockSliceLength_O_ % SrcScalarPerVetor == 0, "");
+       // static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
+       static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
+       static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
+
+       using BlockwiseSumReduce =
+           PartitionedBlockwiseReduction<FloatGemmAcc,
+                                         BlockSize_,
+                                         Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
+                                         Sequence<0, 1>,
+                                         reduce::Add,
+                                         false>; // propagateNaN
+
+       // using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
+       //     make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
+       // using ThreadReduceDstDesc_M =
+       //     decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
+       // using ThreadwiseSumReduce =
+       //     ThreadwiseReduction<FloatGemmAcc,
+       //                         ThreadReduceSrcDesc_M_O,
+       //                         ThreadReduceDstDesc_M,
+       //                         reduce::Add,
+       //                         false>; // propagateNaN
+
+       using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
+                                       DataType,
+                                       ThreadSliceLength_M * ThreadSliceLength_O,
+                                       true>;
+
+       using DstBufType =
+           StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
+   };
+   using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
+
    // QGrad Gemm

    // KGrad Gemm
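To make the new tiling concrete, here is the arithmetic worked out for one plausible instantiation (the values BlockSize = 256, MPerBlock = 128, Gemm1NPerBlock = 128 and 2-byte fp16 elements are assumptions for illustration, not taken from the commit):

    // YDotYGrad_M_O_<256, 128, 128> with sizeof(DataType) == 2:
    constexpr int kBlockSize       = 256;
    constexpr int kBlockSliceM     = 128;     // MPerBlock
    constexpr int kBlockSliceO     = 128;     // Gemm1NPerBlock
    constexpr int kSrcScalarPerVec = 16 / 2;  // 8 elements per 16-byte read
    constexpr int kThreadClusterO  = kBlockSliceO / kSrcScalarPerVec;             // 16
    constexpr int kThreadClusterM  = kBlockSize / kThreadClusterO;                // 16
    constexpr int kThreadSliceO    = kSrcScalarPerVec;                            // 8
    constexpr int kThreadSliceM    = kBlockSliceM * kThreadClusterO / kBlockSize; // 8
    // The two static_asserts in YDotYGrad_M_O_ check exactly this coverage:
    static_assert(kThreadClusterO * kThreadSliceO == kBlockSliceO, "O tile fully covered");
    static_assert(kThreadClusterM * kThreadSliceM == kBlockSliceM, "M tile fully covered");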
@@ -402,14 +452,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / Gemm1NPerBlock;

-       const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+       const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

-       return c_grid_desc_mblock_mperblock_nblock_nperblock;
+       return y_grid_desc_mblock_mperblock_oblock_operblock;
    }

    __host__ __device__ static constexpr auto
@@ -437,7 +487,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            c_grid_desc_m_n);
    }

-   using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+   using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

    using DefaultBlock2CTileMap =
@@ -497,7 +547,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    __device__ static void Run(const DataType* __restrict__ p_a_grid,
                               const DataType* __restrict__ p_b_grid,
                               const DataType* __restrict__ p_b1_grid,
-                              const DataType* __restrict__ p_c_grid,
+                              const DataType* __restrict__ p_y_grid,
                               const FloatLSE* __restrict__ p_lse_grid,
                               const DataType* __restrict__ p_ygrad_grid,
                               DataType* __restrict__ p_qgrad_grid,
@@ -512,8 +562,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
-                              const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
-                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                              const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
+                                  y_grid_desc_mblock_mperblock_oblock_operblock,
                               const LSEGridDesc_M& lse_grid_desc_m,
                               const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
                               const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
@@ -526,8 +576,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-       const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+       const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
        auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
        const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -535,14 +585,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());

-       // divide block work by [M, N]
+       // divide block work by [M, O]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
-              make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-                         c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+              make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
+                         y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
        {
            return;
        }
@@ -1217,7 +1267,175 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               make_multi_index(NPerBlock))[I0]);
        }
#endif

-       constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
+       //
+       // dP
+       //
+       constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
+           I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
+       constexpr auto y_thread_cluster_desc =
+           make_cluster_descriptor(Sequence<I1,
+                                            YDotYGrad_M_O::ThreadClusterLength_M,
+                                            I1,
+                                            YDotYGrad_M_O::ThreadClusterLength_O>{},
+                                   Sequence<0, 1, 2, 3>{});
+       const auto y_thread_cluster_idx =
+           y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+       const auto y_thread_data_on_block_idx =
+           y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
+       const auto y_thread_data_on_grid_idx =
+           make_multi_index(
+               block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
+           y_thread_data_on_block_idx;
+
+       // performs double duty for both y and ygrad
+       auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
+           DataType,
+           DataType,
+           YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
+           decltype(y_thread_desc_m0_m1_o0_o1),
+           decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
+           Sequence<0, 1, 2, 3>,
+           3,                                // SrcVectorDim
+           YDotYGrad_M_O::SrcScalarPerVetor, // SrcScalarPerVector
+           1,                                // SrcScalarStrideInVector
+           true /* ResetCoordAfterRun */,
+           true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
+                                           y_thread_data_on_grid_idx);
+
+       auto y_thread_buf                 = typename YDotYGrad_M_O::SrcBufType{};
+       auto ygrad_thread_buf             = typename YDotYGrad_M_O::SrcBufType{};
+       auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
+
+       auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+           static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
+
+       constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
+           make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
+                                        make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
+
+       // y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
+       // per-thread LSE data and y_dot_ygrad is tiled the same way
+       constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
+           lse_thread_desc_mblock_mrepeat_mwave_mperxdl;
+
+       auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
+           FloatGemmAcc,
+           FloatGemmAcc,
+           decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
+           decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
+           Sequence<1, m0, m1, m2>,
+           Sequence<0, 1, 2, 3>,
+           3,
+           m2,
+           1,
+           false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
+                  make_multi_index(I0,                       // mblock
+                                   acc0_thread_origin[I0],   // mrepeat
+                                   acc0_thread_origin[I2],   // mwave
+                                   acc0_thread_origin[I4])}; // mperxdl
+
+       // clear accum buffers
+       y_dot_ygrad_thread_accum_buf.Clear();
+       y_dot_ygrad_block_accum_buf.Clear();
+
+#if 0
+       if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
+       if(hipBlockIdx_x == 0)
+       {
+           debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
+       }
+#endif
+
+       auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
+           y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
+
+       //
+       // calculate y dot ygrad
+       //
+       index_t oblock_idx = 0;
+       do
+       {
+           yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
+                                      y_grid_buf,
+                                      y_thread_desc_m0_m1_o0_o1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      y_thread_buf);
+           yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
+                                      ygrad_grid_buf,
+                                      y_thread_desc_m0_m1_o0_o1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      ygrad_thread_buf);
+
+           static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
+               static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
+                   constexpr auto offset = y_thread_desc_m0_m1_o0_o1.CalculateOffset(
+                       make_multi_index(I0, iM, I0, iO));
+                   y_dot_ygrad_thread_accum_buf(iM) +=
+                       y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
+               });
+           });
+
+#if 0
+           if (hipThreadIdx_x % 32 < 4 && hipBlockIdx_x == 0)
+           {
+               printf("bid %zd tid %zd, oblock_idx %d, y_thread_buf[0:3] = %f %f %f %f, ygrad_thread_buf[0:3] = %f %f %f %f\n",
+                   hipBlockIdx_x,
+                   hipThreadIdx_x,
+                   oblock_idx,
+                   (float)y_thread_buf[I0],
+                   (float)y_thread_buf[I1],
+                   (float)y_thread_buf[I2],
+                   (float)y_thread_buf[I3],
+                   (float)ygrad_thread_buf[I0],
+                   (float)ygrad_thread_buf[I1],
+                   (float)ygrad_thread_buf[I2],
+                   (float)ygrad_thread_buf[I3]);
+           }
+#endif
+
+           yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
+                                                     make_multi_index(0, 0, 1, 0));
+           oblock_idx++;
+       } while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
+
+       // blockwise reduction using atomic_add
+       block_sync_lds();
+       static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
+           const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
+           y_dot_ygrad_block_accum_buf.AtomicAdd(idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
+       });
+       block_sync_lds();
+
+#if 1
+       if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds after accum\n");
+       if(hipBlockIdx_x == 0)
+       {
+           debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
+       }
+#endif
+
+       // distribute to threads
+       y_dot_ygrad_thread_copy_lds_to_vgpr.Run(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
+                                               y_dot_ygrad_block_accum_buf,
+                                               y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
+                                               make_tuple(I0, I0, I0, I0),
+                                               y_dot_ygrad_thread_buf);
+
+#if 0
+       if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
+       {
+           printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
+               hipBlockIdx_x,
+               hipThreadIdx_x,
+               y_thread_data_on_grid_idx[I0],
+               y_thread_data_on_grid_idx[I1],
+               y_thread_data_on_grid_idx[I2],
+               y_thread_data_on_grid_idx[I3]);
+       }
+#endif
+
        lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
                                           lse_grid_buf,
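The block above is the heart of the commit: each thread reads matching vector strips of Y and dY, keeps per-row partial sums in registers while the copy window walks the O blocks, atomically adds its partials into an LDS array of MPerBlock row sums, and finally redistributes those sums to the register layout used by the softmax backward. A much-simplified, stand-alone HIP/CUDA-style model of the same pattern follows; TILE_M, the float data type and the launch shape (blockDim.x a positive multiple of TILE_M, one block per TILE_M-row strip) are assumptions for illustration, not the kernel's actual tiling:

    #define TILE_M 128 // rows handled per block (assumed)

    __global__ void y_dot_dy_block(const float* y, const float* dy, float* out, int O)
    {
        __shared__ float lds_accum[TILE_M];
        for(int i = threadIdx.x; i < TILE_M; i += blockDim.x)
            lds_accum[i] = 0.f; // mirrors y_dot_ygrad_block_accum_buf.Clear()
        __syncthreads();

        // register partial sum over this thread's strip of the O dimension,
        // mirroring the per-thread accumulation loop over oblock_idx
        const int    row     = threadIdx.x % TILE_M;
        const float* y_row   = y + (blockIdx.x * TILE_M + row) * (size_t)O;
        const float* dy_row  = dy + (blockIdx.x * TILE_M + row) * (size_t)O;
        float        partial = 0.f;
        for(int o = threadIdx.x / TILE_M; o < O; o += blockDim.x / TILE_M)
            partial += y_row[o] * dy_row[o];

        // blockwise reduction via atomic adds into shared memory, then write per row
        atomicAdd(&lds_accum[row], partial);
        __syncthreads();
        if(threadIdx.x < TILE_M)
            out[blockIdx.x * TILE_M + threadIdx.x] = lds_accum[threadIdx.x];
    }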
@@ -1348,6 +1566,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
                                                         blockwise_gemm.GetWaveIdx()[I1]);

+       constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
        static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");

        vgrad_acc_thread_buf.Clear();
@@ -1450,7 +1669,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // TODO ANT:
        // shuffle dQ and write
-       if constexpr(false)
+#if 0
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1646,6 +1865,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                }
            });
        }
+#endif
    }
};
include/ck/utility/dynamic_buffer.hpp

@@ -143,6 +143,16 @@ struct DynamicBuffer
        }
    }

+   __host__ __device__ void Clear()
+   {
+       static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
+                     "wrong! only local data share is supported");
+
+       for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
+       {
+           Set(i, true, T{0});
+       }
+   }
+
    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
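The new Clear() is a block-cooperative zero-fill: every thread of the workgroup strides over the LDS-backed buffer in steps of get_block_size(), so all threads must call it and a barrier is needed before the buffer is read or atomically updated. This mirrors how the gridwise kernel above uses it (excerpt-style sketch built from the CK calls that appear in this commit, not a self-contained program):

    auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
    y_dot_ygrad_block_accum_buf.Clear(); // each thread zeroes a strided subset of the LDS elements
    block_sync_lds();                    // make the zeros visible before the AtomicAdd accumulation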
@@ -302,7 +312,9 @@ struct DynamicBuffer
        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

-       static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+       static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
+                         GetAddressSpace() == AddressSpaceEnum::Lds,
+                     "only support global mem or local data share");

#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
        bool constexpr use_amd_buffer_addressing = false;
#endif

-       if constexpr(use_amd_buffer_addressing)
+       if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;