Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f82a220f
Commit
f82a220f
authored
Jun 01, 2023
by
guangzlu
Browse files
v4 pass
parent
ff88ffa4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
252 additions
and
159 deletions
+252
-159
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v4.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v4.cpp
+14
-14
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+19
-19
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+26
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+192
-125
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
f82a220f
...
@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
...
@@ -766,7 +766,7 @@ int run(int argc, char* argv[])
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v4.cpp
View file @
f82a220f
...
@@ -32,7 +32,7 @@ Kernel outputs:
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
32
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
...
@@ -710,17 +710,17 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
10
00
;
// 512
ck
::
index_t
M
=
5
00
;
// 512
ck
::
index_t
N
=
10
00
;
// 512
ck
::
index_t
N
=
5
00
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G0
=
2
;
// 54
ck
::
index_t
G1
=
6
;
// 16
ck
::
index_t
G1
=
1
;
// 16
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
0
;
float
p_drop
=
0.
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
...
@@ -944,7 +944,7 @@ int run(int argc, char* argv[])
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
static_cast
<
ZDataType
*>
(
z_fwd_device_buf
.
GetDeviceBuffer
()
),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
...
@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
...
@@ -998,7 +998,7 @@ int run(int argc, char* argv[])
auto
argument_bwd
=
gemm_bwd
.
MakeArgument
(
auto
argument_bwd
=
gemm_bwd
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
nullptr
),
// set to nullptr
static_cast
<
ZDataType
*>
(
z_bwd_device_buf
.
GetDeviceBuffer
()
),
// set to nullptr
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
v_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InputDataType
*>
(
y_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
...
@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
...
@@ -1399,20 +1399,20 @@ int run(int argc, char* argv[])
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
,
"error"
,
"error"
,
1e-
2
,
1e-
3
,
1e-
2
);
1e-
3
);
std
::
cout
<<
"Checking kgrad:
\n
"
;
std
::
cout
<<
"Checking kgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
,
"error"
,
"error"
,
1e-
2
,
1e-
3
,
1e-
2
);
1e-
3
);
std
::
cout
<<
"Checking vgrad:
\n
"
;
std
::
cout
<<
"Checking vgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
,
"error"
,
"error"
,
1e-
2
,
1e-
3
,
1e-
2
);
1e-
3
);
}
}
return
pass
?
((
void
)(
std
::
cout
<<
"pass
\n
"
),
0
)
:
((
void
)(
std
::
cout
<<
"fail
\n
"
),
1
);
return
pass
?
((
void
)(
std
::
cout
<<
"pass
\n
"
),
0
)
:
((
void
)(
std
::
cout
<<
"fail
\n
"
),
1
);
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
f82a220f
...
@@ -145,14 +145,14 @@ struct BlockwiseDropout
...
@@ -145,14 +145,14 @@ struct BlockwiseDropout
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
}
ushort
tmp_id
[
tmp_size
];
//
ushort tmp_id[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
for(int j = 0; j < 4; j++)
{
//
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
;
//
tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
}
//
}
}
//
}
block_sync_lds
();
block_sync_lds
();
...
@@ -162,7 +162,7 @@ struct BlockwiseDropout
...
@@ -162,7 +162,7 @@ struct BlockwiseDropout
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
_id
[
tmp_index
];
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
...
@@ -208,17 +208,17 @@ struct BlockwiseDropout
...
@@ -208,17 +208,17 @@ struct BlockwiseDropout
ushort
tmp
[
tmp_size
];
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
}
ushort
tmp_id
[
tmp_size
];
//
ushort tmp_id[tmp_size];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
//
for(int i = 0; i < philox_calls; i++)
{
//
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
for(int j = 0; j < 4; j++)
{
//
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
//
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
}
//
}
}
//
}
block_sync_lds
();
block_sync_lds
();
...
@@ -226,7 +226,7 @@ struct BlockwiseDropout
...
@@ -226,7 +226,7 @@ struct BlockwiseDropout
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
z_thread_buf
(
offset
)
=
tmp
_id
[
tmp_index
];
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
tmp_index
=
tmp_index
+
1
;
});
});
});
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
f82a220f
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
...
@@ -40,6 +41,7 @@ template <typename GridwiseGemm,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
...
@@ -73,6 +75,8 @@ __global__ void
...
@@ -73,6 +75,8 @@ __global__ void
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -138,6 +142,7 @@ __global__ void
...
@@ -138,6 +142,7 @@ __global__ void
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -173,6 +178,7 @@ __global__ void
...
@@ -173,6 +178,7 @@ __global__ void
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -828,6 +834,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n_
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4
(
z_grid_desc_m_n_
);
// tmp z tensor for shuffle
// Tensor<ZDataType> z_tmp_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
// DeviceMem z_tmp_device_buf(sizeof(ZDataType) *
// z_tmp_gs_ms_ns.mDesc.GetElementSpaceSize());
// z_tmp_device_buf.ToDevice(z_tmp_gs_ms_ns.mData.data());
// p_z_tmp_grid_ = reinterpret_cast<ZDataType*>(z_tmp_device_buf.GetDeviceBuffer());
// Print();
// Print();
}
}
...
@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -859,7 +877,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// pointers
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
InputDataType
*
p_b_grid_
;
// ZDataType* p_z_tmp_grid_;
ZDataType
*
p_z_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
const
InputDataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
LSEDataType
*
p_lse_grid_
;
...
@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -890,6 +911,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -952,6 +976,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
LSEGridDesc_M
,
...
@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -986,6 +1011,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
f82a220f
...
@@ -122,6 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -122,6 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M4
=
mfma
.
num_input_blks
;
constexpr
auto
M5
=
mfma
.
group_size
;
constexpr
auto
M5
=
mfma
.
group_size
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
...
@@ -132,6 +133,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -132,6 +133,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
,
8
>
{},
Sequence
<
1
,
3
,
5
,
9
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
,
8
>
{},
Sequence
<
1
,
3
,
5
,
9
>
{}));
}
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4
(
const
ZGridDesc_M_N
&
z_grid_desc_m_n
)
//
{
const
auto
M
=
z_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
z_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
constexpr
auto
M3
=
mfma
.
num_groups_per_blk
;
// 4
constexpr
auto
M4
=
mfma
.
num_input_blks
;
// 2
constexpr
auto
M5
=
mfma
.
group_size
;
// 4
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
MPerBlock
,
MXdlPerWave
,
Gemm0MWaves
,
M3
,
M4
,
M5
)),
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
NPerXdl
/
M5
,
M5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
,
9
>
{},
Sequence
<
1
,
3
,
5
,
8
,
10
>
{}));
}
__device__
static
auto
GetGemm0WaveIdx
()
__device__
static
auto
GetGemm0WaveIdx
()
{
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
const
index_t
thread_id
=
get_thread_local_1d_id
();
...
@@ -399,6 +420,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -399,6 +420,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
=
remove_cvref_t
<
decltype
(
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
ZGridDesc_M_N
{}))
>
;
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
ZGridDesc_M_N
{}))
>
;
using
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
=
remove_cvref_t
<
decltype
(
// for shuffle
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5_N4
(
ZGridDesc_M_N
{}))
>
;
// Q / K / V / dY
// Q / K / V / dY
struct
GemmBlockwiseCopy
struct
GemmBlockwiseCopy
{
{
...
@@ -1235,7 +1259,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1235,7 +1259,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_N3_M5
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
...
@@ -1581,136 +1607,148 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1581,136 +1607,148 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// z vgpr copy to global
// z vgpr copy to global
//
//
// z matrix threadwise desc
// z matrix threadwise desc
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
constexpr
auto
z_
tmp_
thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
n0
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MGroupNum
m3
,
// NGroupNum
m3
,
// MInputNum
m4
,
// NInputNum
m4
,
// registerNum
n2
));
// registerNum
n2
));
// NPerXdl
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
//
I1
,
//
m0
,
//
n0
,
//
m1
,
//
n1
,
//
m2
,
// m0
m3
,
// m1
n2
,
// n0
I1
,
// m2
m4
));
// n1
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
z_
tmp_
thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
;
z_tenor_buffer
_tmp
;
z_tenor_buffer
.
Clear
();
z_tenor_buffer
_tmp
.
Clear
();
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n
3
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_
n3_
m5_n
4
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
_tmp
;
z_tenor_buffer
;
z_tenor_buffer
_tmp
.
Clear
();
z_tenor_buffer
.
Clear
();
// z matrix global desc
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// ignore = p_z_tmp_grid;
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_grid_buf_tmp
=
auto
z_grid_buf_tmp
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// tmp buffer for shuffle
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// tmp buffer for shuffle
p_z_grid
,
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
.
GetElementSpaceSize
());
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
.
GetElementSpaceSize
());
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
auto
z_tmp_thread_copy_vgpr_to_global
=
auto
z_tmp_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ZDataType
,
ZDataType
,
decltype
(
decltype
(
z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
n0
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
m3
,
// NGroupNum
m3
,
// NGroupNum
m4
,
// NInputNum
m4
,
// NInputNum
n2
>
,
n2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
// DstVectorDim
9
,
// DstVectorDim
1
,
// DstScalarPerVector
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
1
,
// DstScalarStrideInVector
true
>
{
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
0
,
// MBlockId
make_multi_index
(
0
,
// MBlockId
block_work_idx_n
,
// NBlockId
block_work_idx_n
,
// NBlockId
0
,
// mrepeat
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_id
[
I1
],
// NWaveId
0
,
// MPerXdl
0
,
// MPerXdl
wave_m_n_id
[
I0
],
// group
wave_m_n_id
[
I0
],
// group
0
,
// NInputIndex
0
,
// NInputIndex
wave_m_n_id
[
I1
]),
wave_m_n_id
[
I1
]),
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_tmp_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
auto
z_tmp_thread_copy_global_to_vgpr
=
ZDataType
,
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ushort
,
ushort
,
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
),
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
),
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
Sequence
<
I1
,
I1
,
m0
,
n0
,
m1
,
n1
,
m2
,
m3
,
n2
,
I1
,
m4
>
,
Sequence
<
I1
,
I1
,
m0
,
n0
,
m1
,
n1
,
m2
,
m3
,
m4
,
n2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
10
,
9
,
1
,
1
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_n3_m5_n4
,
true
/* ResetCoordAfterRun */
>
{
make_multi_index
(
0
,
// MBlockId
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
block_work_idx_n
,
// NBlockId
make_multi_index
(
0
,
// MBlockId
0
,
// mrepeat
block_work_idx_n
,
// NBlockId
0
,
// nrepeat
0
,
// mrepeat
wave_id
[
I0
],
// MWaveId
0
,
// nrepeat
wave_id
[
I1
],
// NWaveId
wave_id
[
I0
],
// MWaveId
0
,
// MPerXdl
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I0
],
// group
0
,
// MPerXdl
int
(
wave_m_n_id
[
I1
]
/
4
),
// NInputIndex
wave_m_n_id
[
I0
],
// group
wave_m_n_id
[
I1
]
%
4
,
0
,
// NInputIndex
0
)};
wave_m_n_id
[
I1
])};
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ushort
,
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
ZDataType
,
decltype
(
z_tmp_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
decltype
(
decltype
(
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
tensor_operation
::
element_wise
::
PassThrough
,
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
),
Sequence
<
I1
,
// MBlockId
tensor_operation
::
element_wise
::
PassThrough
,
I1
,
// NBlockID
Sequence
<
I1
,
// MBlockId
m0
,
// MRepeat
I1
,
// NBlockID
n0
,
// NRepeat
m0
,
// MRepeat
m1
,
// MWaveId
n0
,
// NRepeat
n1
,
// NWaveId
m1
,
// MWaveId
m2
,
// MPerXdl
n1
,
// NWaveId
m3
,
// NGroupNum
m2
,
// MPerXdl
m4
,
// NInputNum
m3
,
// NGroupNum
n2
>
,
m4
,
// NInputNum
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
n2
>
,
9
,
// DstVectorDim
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
1
,
// DstScalarPerVector
9
,
// DstVectorDim
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarPerVector
1
,
// DstScalarStrideInVector
InMemoryDataOperationEnum
::
Set
,
true
>
{
z_tmp_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
1
,
// DstScalarStrideInVector
make_multi_index
(
0
,
// MBlockId
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
block_work_idx_n
,
// NBlockId
make_multi_index
(
0
,
// MBlockId
0
,
// mrepeat
block_work_idx_n
,
// NBlockId
0
,
// nrepeat
0
,
// mrepeat
wave_id
[
I0
],
// MWaveId
0
,
// nrepeat
wave_id
[
I1
],
// NWaveId
wave_id
[
I0
],
// MWaveId
0
,
// MPerXdl
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I0
],
// group
0
,
// MPerXdl
0
,
// NInputIndex
wave_m_n_id
[
I0
],
// group
wave_m_n_id
[
I1
]),
0
,
// NInputIndex
tensor_operation
::
element_wise
::
PassThrough
{}};
wave_m_n_id
[
I1
]),
tensor_operation
::
element_wise
::
PassThrough
{}};
//
//
// set up Y dot dY
// set up Y dot dY
...
@@ -1981,6 +2019,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1981,6 +2019,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
// auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
// auto m_local =
// block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
// auto n_local =
// block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
// auto m_global = m_local + m_block_data_idx_on_grid;
// auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
//}
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
Acc0TileIterator
::
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
i
)
+
acc0_thread_origin
;
auto
m_local
=
auto
m_local
=
...
@@ -2021,19 +2074,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2021,19 +2074,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
{
// 8d thread_desc in thread scope
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
N2_N3_N4
().
GetLengths
();
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
M3_M4_N2
().
GetLengths
();
// 8d block_desc in block scope
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_
N2_N3_N4
().
GetLengths
();
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_
M3_M4_N2
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
M3
=
c_block_lengths
[
I
6
];
constexpr
auto
M3
=
c_block_lengths
[
I
5
];
constexpr
auto
M4
=
c_block_lengths
[
I
5
];
constexpr
auto
M4
=
c_block_lengths
[
I
6
];
constexpr
auto
N2
=
c_block_lengths
[
I7
];
constexpr
auto
N2
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// works like multi-dimension static_for (static_ford), but provides both the linear
...
@@ -2050,12 +2103,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2050,12 +2103,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto
global_elem_id_raw
=
auto
global_elem_id_raw
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
...
@@ -2082,17 +2145,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2082,17 +2145,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ph
,
global_elem_id
,
z_tenor_buffer_tmp
,
MRaw
);
ph
,
global_elem_id
,
z_tenor_buffer_tmp
,
MRaw
);
z_tmp_thread_copy_vgpr_to_global
.
Run
(
z_tmp_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_
tmp_
thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer_tmp
,
z_tenor_buffer_tmp
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_
tmp_
grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf_tmp
);
z_grid_buf_tmp
);
block_sync_lds
();
z_tmp_thread_copy_global_to_vgpr
.
Run
(
z_tmp_thread_copy_global_to_vgpr
.
Run
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n
3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_
n3_
m5_n
4
,
z_grid_buf_tmp
,
z_grid_buf_tmp
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n
3
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_
n3_
m5_n
4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
z_tenor_buffer
);
blockwise_dropout
.
template
ApplyDropout_v2
<
decltype
(
s_slash_p_thread_buf
),
blockwise_dropout
.
template
ApplyDropout_v2
<
decltype
(
s_slash_p_thread_buf
),
...
@@ -2100,11 +2165,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2100,11 +2165,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
true
>(
s_slash_p_thread_buf
,
true
>(
s_slash_p_thread_buf
,
z_tenor_buffer
);
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_copy_vgpr_to_global
.
Run
(
z_
tmp_
thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_
tmp_
grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_buf
);
z_grid_buf
);
block_sync_lds
();
//// P_dropped
//// P_dropped
// static_for<0, n0, 1>{}([&](auto i) {
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
...
@@ -2132,11 +2199,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2132,11 +2199,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
{
// 8d thread_desc in thread scope
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
N2_N3_N4
().
GetLengths
();
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
M3_M4_N2
().
GetLengths
();
// 8d block_desc in block scope
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_
N2_N3_N4
().
GetLengths
();
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_
M3_M4_N2
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
...
@@ -2362,13 +2429,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -2362,13 +2429,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
z_tmp_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_tmp_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_
tmp_
grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
z_tmp_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
z_tmp_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n
3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_
n3_
m5_n
4
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_
tmp_
grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment