gaoqiong / composable_kernel / Commits / 7a302cc9

Commit 7a302cc9, authored Dec 20, 2022 by Anthony Chang

    debugging dQ; suspected K mat not properly loaded

parent b637c77d

Showing 2 changed files with 368 additions and 97 deletions (+368 -97):

    example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp (+18 -2)
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+350 -95)

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
@@ -104,7 +104,7 @@ using DeviceGemmInstance =
        TensorSpecY,
        1,
        256,
-       256,         // MPerBlock
+       128,         // MPerBlock
        128,         // NPerBlock
        32,          // KPerBlock
        64,          // Gemm1NPerBlock
@@ -114,7 +114,7 @@ using DeviceGemmInstance =
        2,           // B1K1
        32,          // MPerXDL
        32,          // NPerXDL
-       2,           // MXdlPerWave
+       1,           // MXdlPerWave
        4,           // NXdlPerWave
        2,           // Gemm1NXdlPerWave
        S<4, 64, 1>, // ABlockTransfer
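A quick consistency check on the retuned tile shape (a sketch only; it assumes the usual CK XDL decomposition where the M extent of a block tile is covered by waves times per-wave XDL repeats, and a 256-thread block made of 4 waves of 64 lanes -- the wave counts are inferred, not stated in this diff):

    MPerBlock = MXdlPerWave * MPerXdl * MWaves_M
    old: 256 = 2 * 32 * 4
    new: 128 = 1 * 32 * 4

Under that assumption the number of waves along M stays at 4, so the block is still fully occupied; only the per-wave repeat count along M is halved together with MPerBlock.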
@@ -375,18 +375,34 @@ int run(int argc, char* argv[])
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+       // dO dot O = [0; 1; 2; ...]
        break;
    case 6:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+       // assume mnko = 256
+       // P = softmax(QK) = 0.0039 * ones
+       // O = P V = 0.0039 * ones
+       // dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
+       // dO dot O = [127.5; ...]
+       // dS = P * (dP - dO dot O)
+       //
        break;
    default:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
+       // assume mnko = 256
+       // P = softmax(QK) = 0.0039 * ones
+       // O = P V = 0.0039 * ones
+       // dP = dO V = ones
+       // dS = P * (dP - (dO dot O))
+       //    = 0.0039 * ones * (ones - 0.0039*256)
+       //    = 0.0039 * ones * (ones - 1)
+       //    = 0
    }

    // calculate y & log-sum-exp beforehand
    ...
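The hand-computed values in the new comments follow from the standard softmax/attention backward identity (a sketch; the 1/sqrt(d) scale is omitted, matching the comments):

    S = Q K^T,  P = softmax(S),  O = P V
    dP = dO V^T
    dS[i][j] = P[i][j] * (dP[i][j] - sum_o dO[i][o] * O[i][o])

For the default generator case (Q = ones, K = V = identity, dO = ones, m = n = k = o = 256): every row of P is 1/256 ~= 0.0039, O = P V = P, dP = dO V^T = ones, and the row dot product dO.O = 256 * (1/256) = 1, so dS = 0.0039 * (1 - 1) = 0, which is exactly what the "= 0" comment records. Case 6 uses a sequential dO, which is why its dO.O column is the mean of 0..255, i.e. 127.5.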
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -28,8 +28,8 @@ template <typename DataType,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename AGridDesc_AK0_M_AK1,
-         typename BGridDesc_BK0_N_BK1,
+         typename QGridDesc_K0_M_K1,
+         typename KGridDesc_K0_N_K1,
          typename VGridDesc_N0_O_N1,
          typename CGridDesc_M_N,
          typename LSEGridDesc_M,
@@ -335,8 +335,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
-   CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+   CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
+                 const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
                  const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
@@ -345,9 +345,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

-       const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
-       const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
-       const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
+       const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
+       const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
+       const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
        const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
@@ -446,7 +446,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-   // PGrad Gemm has the same layout as P Gemm (A row-major B col-major)
+   // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
    struct PGradGemmTile_M_N_O
    {
        private:
@@ -521,6 +521,82 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        }
    };

+   // QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
+   struct QGradGemmTile_M_K_N
+   {
+       template <typename QGridDesc_K0_M_K1_>
+       __device__ static const auto MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
+           const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
+       {
+           const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
+           const auto M  = q_grid_desc_k0_m_k1.GetLength(I1);
+           const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
+           const auto K  = K0 * K1;
+
+           const auto MBlock = M / MPerBlock;
+           const auto KBlock = K / Gemm1NPerBlock; // NOTE: QGrad gemm is similar to Y gemm
+
+           const auto q_grid_desc_m_k = transform_tensor_descriptor(
+               q_grid_desc_k0_m_k1,
+               make_tuple(make_pass_through_transform(M),
+                          make_merge_transform_v3_division_mod(make_tuple(K0, K1))),
+               make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+               make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+           return transform_tensor_descriptor(
+               q_grid_desc_m_k,
+               make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                          make_unmerge_transform(make_tuple(KBlock, Number<Gemm1NPerBlock>{}))),
+               make_tuple(Sequence<0>{}, Sequence<1>{}),
+               make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+       }
+
+       template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
+       __device__ static const auto MakeSGradThreadDesc_N0_M_N1(
+           const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_& sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
+       {
+           constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
+           constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
+           constexpr auto m1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
+           constexpr auto n1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
+           constexpr auto m2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
+           constexpr auto n2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
+           constexpr auto n3 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
+           constexpr auto n4 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+
+           constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
+               sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+               make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
+                          make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
+                          make_pass_through_transform(n4)),
+               make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
+               make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+           return sgrad_thread_desc_n0_m_n1;
+       }
+
+       template <typename KGridDesc_K0_N_K1_>
+       __device__ static const auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
+       {
+           const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
+           const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
+           const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
+
+           constexpr auto K_N1 = BK1;
+           const auto K_N0     = N / K_N1;
+
+           const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
+               k_grid_desc_k0_n_k1,
+               make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
+                          make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
+               make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+           return k_grid_desc_n0_k_n1;
+       }
+   };
+
    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
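The first helper above only rearranges indices; nothing is moved in memory. As a standalone host-side illustration (not CK code, sizes and the helper name are made up), the mapping it builds is: merge a (k0, m, k1) coordinate into (m, k = k0 * K1 + k1), then split each axis into a block index and an in-block offset:

#include <array>
#include <cstdio>

int main()
{
    // Hypothetical extents; in the kernel KBlock is derived from Gemm1NPerBlock.
    const int K1 = 8, MPerBlock = 128, KPerBlock = 64;

    auto map = [&](int k0, int m, int k1) {
        const int k = k0 * K1 + k1; // merge (k0, k1) -> k, as the merge transform does
        return std::array<int, 4>{m / MPerBlock, m % MPerBlock,   // unmerge m -> (MBlock idx, offset)
                                  k / KPerBlock, k % KPerBlock};  // unmerge k -> (KBlock idx, offset)
    };

    const auto idx = map(/*k0=*/3, /*m=*/200, /*k1=*/5); // k = 3*8 + 5 = 29
    std::printf("mblock %d mperblock %d kblock %d kperblock %d\n",
                idx[0], idx[1], idx[2], idx[3]); // prints: 1 72 0 29
    return 0;
}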
@@ -572,8 +648,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
              typename C0MatrixMask,
              typename VGradGridDescriptor_N_O,
              typename YGradGridDesc_M0_O_M1>
-   __device__ static void Run(const DataType* __restrict__ p_a_grid,
-                              const DataType* __restrict__ p_b_grid,
+   __device__ static void Run(const DataType* __restrict__ p_q_grid,
+                              const DataType* __restrict__ p_k_grid,
                               const DataType* __restrict__ p_v_grid,
                               const DataType* __restrict__ p_y_grid,
                               const FloatLSE* __restrict__ p_lse_grid,
@@ -587,8 +663,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               const AccElementwiseOperation& acc_element_op,
                               const B1ElementwiseOperation& b1_element_op,
                               const CElementwiseOperation& c_element_op,
-                              const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
-                              const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+                              const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
+                              const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
                               const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                               const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
                                   y_grid_desc_mblock_mperblock_oblock_operblock,
@@ -598,10 +674,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               const Block2CTileMap& block_2_ctile_map,
                               const C0MatrixMask& c0_matrix_mask)
    {
-       const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-       const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+       const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
+       const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
        const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
        const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -612,6 +688,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
        auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
+       auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());

        // divide block work by [M, O]
        const auto block_work_idx =
@@ -653,7 +731,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                ABlockTransferThreadClusterArrangeOrder,
                                DataType,
                                DataType,
-                               decltype(a_grid_desc_ak0_m_ak1),
+                               decltype(q_grid_desc_k0_m_k1),
                                decltype(a_block_desc_ak0_m_ak1),
                                ABlockTransferSrcAccessOrder,
                                Sequence<1, 0, 2>,
@@ -666,7 +744,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                true, // SrcResetCoord
                                true, // DstResetCoord
                                NumGemmKPrefetchStage>(
-               a_grid_desc_ak0_m_ak1,
+               q_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
@@ -684,7 +762,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                BBlockTransferThreadClusterArrangeOrder,
                                DataType,
                                DataType,
-                               decltype(b_grid_desc_bk0_n_bk1),
+                               decltype(k_grid_desc_k0_n_k1),
                                decltype(b_block_desc_bk0_n_bk1),
                                BBlockTransferSrcAccessOrder,
                                Sequence<1, 0, 2>,
@@ -697,7 +775,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                true, // SrcResetCoord
                                true, // DstResetCoord
                                NumGemmKPrefetchStage>(
-               b_grid_desc_bk0_n_bk1,
+               k_grid_desc_k0_n_k1,
                make_multi_index(0, 0, 0), // will loop over GemmN dimension
                b_element_op,
                b_block_desc_bk0_n_bk1,
@@ -746,9 +824,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
        const auto a_block_reset_copy_step =
-           make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
+           make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
        const auto b_block_reset_copy_step =
-           make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
+           make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);

        // gridwise GEMM pipeline
        // Only supports LoopScheduler::Default
@@ -757,11 +835,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            LoopScheduler::Default>();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-           (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+           (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) /
            KPerBlock);

        //
-       // set up Gemm 1
+       // set up O / dQ Gemm
        //

        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
@@ -811,47 +889,47 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A1 matrix blockwise copy
-       auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-           FloatGemmAcc,
-           DataType,
-           decltype(acc_thread_desc_k0_m_k1),
-           decltype(a1_thread_desc_k0_m_k1),
-           tensor_operation::element_wise::PassThrough,
-           Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
-           Sequence<1, 0, 2>,
-           2,
-           n4>{tensor_operation::element_wise::PassThrough{}};
+       // auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+       //     FloatGemmAcc,
+       //     DataType,
+       //     decltype(acc_thread_desc_k0_m_k1),
+       //     decltype(a1_thread_desc_k0_m_k1),
+       //     tensor_operation::element_wise::PassThrough,
+       //     Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+       //     Sequence<1, 0, 2>,
+       //     2,
+       //     n4>{tensor_operation::element_wise::PassThrough{}};

        // B1 matrix blockwise copy
-       auto b1_blockwise_copy =
-           ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                               BElementwiseOperation,
-                                               tensor_operation::element_wise::PassThrough,
-                                               InMemoryDataOperationEnum::Set,
-                                               Sequence<B1K0, Gemm1NPerBlock, B1K1>,
-                                               B1BlockTransferThreadClusterLengths_BK0_N_BK1,
-                                               B1BlockTransferThreadClusterArrangeOrder,
-                                               DataType,
-                                               DataType,
-                                               decltype(v_grid_desc_n0_o_n1),
-                                               decltype(b1_block_desc_bk0_n_bk1),
-                                               B1BlockTransferSrcAccessOrder,
-                                               Sequence<1, 0, 2>,
-                                               B1BlockTransferSrcVectorDim,
-                                               2,
-                                               B1BlockTransferSrcScalarPerVector,
-                                               B1BlockTransferDstScalarPerVector_BK1,
-                                               1,
-                                               1,
-                                               B1ThreadTransferSrcResetCoordinateAfterRun,
-                                               true, // DstResetCoord
-                                               NumGemmKPrefetchStage>(
-               v_grid_desc_n0_o_n1,
-               make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
-               b1_element_op,
-               b1_block_desc_bk0_n_bk1,
-               make_multi_index(0, 0, 0),
-               tensor_operation::element_wise::PassThrough{});
+       // auto b1_blockwise_copy =
+       //     ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+       //                                         BElementwiseOperation,
+       //                                         tensor_operation::element_wise::PassThrough,
+       //                                         InMemoryDataOperationEnum::Set,
+       //                                         Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+       //                                         B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+       //                                         B1BlockTransferThreadClusterArrangeOrder,
+       //                                         DataType,
+       //                                         DataType,
+       //                                         decltype(v_grid_desc_n0_o_n1),
+       //                                         decltype(b1_block_desc_bk0_n_bk1),
+       //                                         B1BlockTransferSrcAccessOrder,
+       //                                         Sequence<1, 0, 2>,
+       //                                         B1BlockTransferSrcVectorDim,
+       //                                         2,
+       //                                         B1BlockTransferSrcScalarPerVector,
+       //                                         B1BlockTransferDstScalarPerVector_BK1,
+       //                                         1,
+       //                                         1,
+       //                                         B1ThreadTransferSrcResetCoordinateAfterRun,
+       //                                         true, // DstResetCoord
+       //                                         NumGemmKPrefetchStage>(
+       //         v_grid_desc_n0_o_n1,
+       //         make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+       //         b1_element_op,
+       //         b1_block_desc_bk0_n_bk1,
+       //         make_multi_index(0, 0, 0),
+       //         tensor_operation::element_wise::PassThrough{});

        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
            a1_thread_desc_k0_m_k1.GetElementSpaceSize());
@@ -944,7 +1022,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            decltype(thread_slice_desc_m_n)>{};

        const index_t num_gemm1_k_block_outer_loop =
-           b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
+           k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

        // Initialize C
@@ -1178,7 +1256,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            VGradGemmTile_N_O_M::GemmMPack,
            true>{}; // TranspossC

-       auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
+       auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();

        // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
        // variable I1 there
@@ -1351,7 +1429,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto v_grid_desc_o0_n_o1 =
            PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);

-       // A matrix blockwise copy
+       // dP Gemm A matrix blockwise copy
        auto pgrad_gemm_tile_ygrad_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                tensor_operation::element_wise::PassThrough,
@@ -1382,7 +1460,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

-       // B matrix blockwise copy
+       // dP Gemm B matrix blockwise copy
        auto pgrad_gemm_tile_v_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                tensor_operation::element_wise::PassThrough,
@@ -1454,6 +1532,81 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
            y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());

+       //
+       // dQ
+       //
+       const auto k_grid_desc_n0_k_n1 =
+           QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
+       auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
+           QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
+               q_grid_desc_k0_m_k1);
+
+       // dQ Gemm A matrix blockwise copy
+       auto qgrad_gemm_tile_sgrad_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+           FloatGemmAcc,
+           DataType,
+           decltype(acc_thread_desc_k0_m_k1), // reuse desc
+           decltype(a1_thread_desc_k0_m_k1),  // reuse desc
+           tensor_operation::element_wise::PassThrough,
+           Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+           Sequence<1, 0, 2>,
+           2,
+           n4>{tensor_operation::element_wise::PassThrough{}};
+
+       // dQ Gemm B matrix blockwise copy
+       auto qgrad_gemm_tile_k_blockwise_copy =
+           ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                               BElementwiseOperation,
+                                               tensor_operation::element_wise::PassThrough,
+                                               InMemoryDataOperationEnum::Set,
+                                               Sequence<B1K0, Gemm1NPerBlock, B1K1>,          // reuse from V
+                                               B1BlockTransferThreadClusterLengths_BK0_N_BK1, // reuse from V
+                                               B1BlockTransferThreadClusterArrangeOrder,      // reuse from V
+                                               DataType,
+                                               DataType,
+                                               decltype(k_grid_desc_n0_k_n1),
+                                               decltype(b1_block_desc_bk0_n_bk1), // reuse from V
+                                               B1BlockTransferSrcAccessOrder,
+                                               Sequence<1, 0, 2>,
+                                               B1BlockTransferSrcVectorDim,
+                                               2,
+                                               B1BlockTransferSrcScalarPerVector,
+                                               B1BlockTransferDstScalarPerVector_BK1,
+                                               1,
+                                               1,
+                                               B1ThreadTransferSrcResetCoordinateAfterRun,
+                                               true, // DstResetCoord
+                                               NumGemmKPrefetchStage>(
+               k_grid_desc_n0_k_n1,
+               make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+               b1_element_op,
+               b1_block_desc_bk0_n_bk1,
+               make_multi_index(0, 0, 0),
+               tensor_operation::element_wise::PassThrough{});
+
+       auto qgrad_blockwise_gemm = BlockwiseGemmXdlops_v2<
+           BlockSize,
+           DataType,
+           FloatGemmAcc,
+           decltype(a1_thread_desc_k0_m_k1),
+           decltype(b1_block_desc_bk0_n_bk1),
+           decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
+           decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
+           MPerBlock,
+           Gemm1NPerBlock,
+           Gemm1KPerBlock,
+           MPerXdl,
+           NPerXdl,
+           MXdlPerWave,
+           Gemm1NXdlPerWave,
+           Gemm1KPack,
+           true,       // TransposeC
+           Gemm1KPack, // AMmaKStride
+           Gemm1KPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
+           // BMmaKStride
+           make_tuple(0, 0, 0, 0)}; // A_origin
+
+       auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
+
        //
        // calculate y dot ygrad
        //
@@ -1528,7 +1681,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            make_tuple(I0, I0, I0, I0),
            y_dot_ygrad_thread_buf);

-#if 0
+#if 1
        if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
        {
            printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
@@ -1547,6 +1700,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            make_tuple(I0, I0, I0, I0),
            lse_thread_buf);

+       // Initialize dQ
+       qgrad_thread_buf.Clear();
+
        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;
        do
@@ -1559,16 +1716,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                continue;
            }

            // gemm0
-           gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+           gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
                                                                   a_block_desc_ak0_m_ak1,
                                                                   a_blockwise_copy,
-                                                                  a_grid_buf,
+                                                                  q_grid_buf,
                                                                   a_block_buf,
                                                                   a_block_slice_copy_step,
-                                                                  b_grid_desc_bk0_n_bk1,
+                                                                  k_grid_desc_k0_n_k1,
                                                                   b_block_desc_bk0_n_bk1,
                                                                   b_blockwise_copy,
-                                                                  b_grid_buf,
+                                                                  k_grid_buf,
                                                                   b_block_buf,
                                                                   b_block_slice_copy_step,
                                                                   blockwise_gemm,
@@ -1673,7 +1830,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
            static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");

-           vgrad_acc_thread_buf.Clear();
+           vgrad_thread_buf.Clear();

            // TODO ANT: single buffer prefetch pipeline
            static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
@@ -1736,7 +1893,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
#endif
                block_sync_lds(); // sync before read
-               vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
+               vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
#if 0
                if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
@@ -1745,10 +1902,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                           gemm1_k_block_outer_index,
                           vgrad_gemm_loop_idx.value,
                           hipThreadIdx_x,
-                          vgrad_acc_thread_buf[I0],
-                          vgrad_acc_thread_buf[I1],
-                          vgrad_acc_thread_buf[I2],
-                          vgrad_acc_thread_buf[I3]);
+                          vgrad_thread_buf[I0],
+                          vgrad_thread_buf[I1],
+                          vgrad_thread_buf[I2],
+                          vgrad_thread_buf[I3]);
                }
#endif
            }); // end gemm dV
@@ -1756,7 +1913,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            // atomic_add dV
            vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                                vgrad_acc_thread_buf,
+                                                vgrad_thread_buf,
                                                 vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                 vgrad_grid_buf);
@@ -1779,10 +1936,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                pgrad_blockwise_gemm,
                pgrad_thread_buf,
                num_o_block_main_loop);

-#if 0
+#if 1
            if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
            {
-               printf("j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
+               printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
                       gemm1_k_block_outer_index,
                       hipThreadIdx_x,
                       pgrad_thread_buf[I0],
@@ -1806,13 +1963,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                    pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I1];
                // dS and P has same thread buf layout
                sgrad_thread_buf(i) =
-                   acc_thread_buf[i] * (pgrad_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}]);
+                   acc_thread_buf[i] * (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
            });

-#if 0
+#if 1
            if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
            {
-               printf("j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
+               printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
                       gemm1_k_block_outer_index,
                       hipThreadIdx_x,
                       sgrad_thread_buf[I0],
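The corrected line above replaces a '*' with a '-', i.e. dS[j] = P[j] * (dP[j] - sum_o dO[o]*O[o]). A host-side reference one could compare the printed dS values against (not CK code; sizes chosen to match the "assume mnko = 256" test comments):

#include <cstdio>
#include <vector>

int main()
{
    const int n = 256; // hypothetical row length, matching the "assume mnko = 256" comment
    std::vector<float> P(n, 1.0f / n), dP(n, 1.0f), dO(n, 1.0f), O(n, 1.0f / n);

    float dO_dot_O = 0.f;
    for(int o = 0; o < n; ++o)
        dO_dot_O += dO[o] * O[o]; // = 1 for this test input

    std::vector<float> dS(n);
    for(int j = 0; j < n; ++j)
        dS[j] = P[j] * (dP[j] - dO_dot_O); // expected 0 for the default generator case

    std::printf("dS[0] = %f\n", dS[0]);
    return 0;
}

With the earlier '*' in place of '-', the row-sum term would multiply dP instead of being subtracted and dS would not vanish for this input.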
@@ -1822,10 +1979,90 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
#endif

+           // gemm dQ
+           {
+               // TODO: explore using dynamic buffer for a1 thread buffer
+               // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
+               // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
+               // the A1 source buffer is static buffer holding the output of first GEMM and
+               // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
+               // explicitly in Run() below.
+
+               // preload data into LDS
+               qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
+               qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
+                                                                   b1_block_slice_copy_step);
+
+               block_sync_lds(); // wait for previous LDS read
+
+               qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
+#if 0
+               if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds dQ gemm K matrix\n");
+               if(hipBlockIdx_x == 0)
+               {
+                   debug::print_shared(b1_block_buf.p_data_,
+                                       (index_t)b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
+               }
+#endif
+
+               // main body
+               if constexpr(num_gemm1_k_block_inner_loop > 1)
+               {
+                   static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
+                       qgrad_gemm_tile_sgrad_blockwise_copy.Run(
+                           acc_thread_desc_k0_m_k1,
+                           make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
+                           sgrad_thread_buf,
+                           a1_thread_desc_k0_m_k1,
+                           make_tuple(I0, I0, I0),
+                           a1_thread_buf);
+#if 0
+                       if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
+                       {
+                           printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
+                                  i.value,
+                                  hipThreadIdx_x,
+                                  (float)a1_thread_buf[I0],
+                                  (float)a1_thread_buf[I1],
+                                  (float)a1_thread_buf[I2],
+                                  (float)a1_thread_buf[I3]);
+                       }
+#endif
+                       qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
+
+                       block_sync_lds();
+
+                       qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+
+                       block_sync_lds();
+
+                       qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
+                                                                           b1_block_slice_copy_step);
+                       qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
+                                                                 b1_block_buf);
+                   });
+               }
+
+               // tail
+               {
+                   qgrad_gemm_tile_sgrad_blockwise_copy.Run(
+                       acc_thread_desc_k0_m_k1,
+                       make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{},
+                                  I0,
+                                  I0),
+                       sgrad_thread_buf,
+                       a1_thread_desc_k0_m_k1,
+                       make_tuple(I0, I0, I0),
+                       a1_thread_buf);
+
+                   block_sync_lds();
+
+                   qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+               }
+           } // end gemm dQ
+
            // move slice window
-           a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
+           a_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
                                                a_block_reset_copy_step); // rewind K
-           b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
+           b_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_k0_n_k1,
                                                b_block_reset_copy_step); // rewind K and step N
            ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
                                                    ygrad_block_reset_copy_step); // rewind M
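The block above accumulates dQ = dS * K tile by tile, with K staged through LDS by qgrad_gemm_tile_k_blockwise_copy. A host-side reference (not CK code, toy sizes) that one could diff the kernel's dQ against while chasing the "K mat not properly loaded" suspicion -- if the K tile in LDS were wrong, dQ would diverge from this even when the dS prints above look correct:

#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 4, Kdim = 4; // hypothetical tiny problem
    std::vector<float> dS(M * N, 0.25f), K(N * Kdim, 0.0f), dQ(M * Kdim, 0.0f);
    for(int n = 0; n < N; ++n)
        K[n * Kdim + n] = 1.0f; // diagonal K, like GeneratorTensor_Diagonal in the example

    // dQ[m][k] = sum_n dS[m][n] * K[n][k]
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < Kdim; ++k)
            for(int n = 0; n < N; ++n)
                dQ[m * Kdim + k] += dS[m * N + n] * K[n * Kdim + k];

    std::printf("dQ[0][0] = %f\n", dQ[0]); // 0.25 with these inputs
    return 0;
}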
@@ -1841,7 +2078,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            // TODO ANT:
            // shuffle dQ and write
-#if 0
+#if 1
+       if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
+       {
+           printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
+                  hipThreadIdx_x,
+                  qgrad_thread_buf[I0],
+                  qgrad_thread_buf[I1],
+                  qgrad_thread_buf[I2],
+                  qgrad_thread_buf[I3]);
+       }
+#endif
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1968,7 +2217,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                FloatCShuffle,        // typename SrcData,
                DataType,             // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-               decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+               decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                3,                    // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
@@ -1976,7 +2225,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
-                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};
@@ -2013,7 +2262,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                             c_thread_buf,
+                                             qgrad_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                              c_shuffle_block_buf);
@@ -2024,20 +2273,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
-                   c_grid_desc_mblock_mperblock_nblock_nperblock,
-                   c_grid_buf);
+                   qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
+                   qgrad_grid_buf);
+#if 0
+               if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds dQ shuffle loop %d\n", access_id.value);
+               if(hipBlockIdx_x == 1)
+               {
+                   debug::print_shared(c_shuffle_block_buf.p_data_,
+                                       (index_t)c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+               }
+#endif

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
-                       c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
+                       qgrad_grid_desc_mblock_mperblock_kblock_kperblock, c_global_step);
                }
            });
        }
-#endif
    }
};