Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
60f1f93a
Unverified
Commit
60f1f93a
authored
May 24, 2023
by
guangzlu
Committed by
GitHub
May 24, 2023
Browse files
Merge branch 'attn-bwd-develop-qloop-dropout' into attn-bwd-develop-qloop
parents
1a306e0d
26115ce7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
361 additions
and
83 deletions
+361
-83
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+23
-24
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+73
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
+5
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+18
-6
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
+139
-28
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+92
-24
include/ck/utility/philox_rand.hpp
include/ck/utility/philox_rand.hpp
+11
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
60f1f93a
...
...
@@ -43,7 +43,7 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v
1
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v
3
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
...
@@ -72,13 +72,14 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
ZDataType
=
INT32
;
// INT32
using
ZDataType
=
U16
;
// INT32
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
...
...
@@ -89,7 +90,7 @@ static constexpr ck::index_t NumDimK = 1;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
4
;
//
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
...
...
@@ -189,8 +190,7 @@ using DeviceGemmInstanceBWD =
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -248,9 +248,8 @@ using DeviceGemmInstanceBWD =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstanceFWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
...
...
@@ -330,8 +329,7 @@ using DeviceGemmInstanceBWD =
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
OutputDataType
,
DataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -352,7 +350,7 @@ using DeviceGemmInstanceBWD =
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
64
,
// NPerBlock
64
,
// KPerBlock
64
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
...
...
@@ -362,9 +360,9 @@ using DeviceGemmInstanceBWD =
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
1
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -387,11 +385,10 @@ using DeviceGemmInstanceBWD =
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
...
...
@@ -723,8 +720,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
// 512
ck
::
index_t
N
=
512
;
// 512
ck
::
index_t
M
=
1000
;
// 512
ck
::
index_t
N
=
1000
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
// 54
...
...
@@ -733,7 +730,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
2
;
float
p_drop
=
0.
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -1043,8 +1040,9 @@ int run(int argc, char* argv[])
YElementOp
{},
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
vgrad_device_buf
.
SetZero
();
qgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
float
ave_time_bwd
=
invoker_bwd
.
Run
(
argument_bwd
,
StreamConfig
{
nullptr
,
true
});
// 5 GEMM ops in total:
...
...
@@ -1152,8 +1150,9 @@ int run(int argc, char* argv[])
std
::
ofstream
fwd_file
(
"./z_fwd_matrix_txt"
);
fwd_file
<<
z_fwd_gs_ms_ns
<<
std
::
endl
;
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
qgrad_device_buf
.
SetZero
();
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto
argument_bwd
=
gemm_bwd
.
MakeArgument
(
static_cast
<
InputDataType
*>
(
q_device_buf
.
GetDeviceBuffer
()),
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
60f1f93a
...
...
@@ -50,6 +50,41 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
ZThreadBuffer
&
z_thread_buf
)
...
...
@@ -86,6 +121,44 @@ struct BlockwiseDropout
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
View file @
60f1f93a
...
...
@@ -85,7 +85,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -178,6 +180,7 @@ __global__ void
ph
,
0
);
}
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -1005,6 +1008,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
60f1f93a
...
...
@@ -79,7 +79,9 @@ __global__ void
const
ushort
p_dropout_in_16bits
,
const
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -112,8 +114,8 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -131,6 +133,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
}
}
...
...
@@ -141,8 +146,8 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -160,6 +165,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_rescale
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
}
#else
...
...
@@ -644,6 +652,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
{
is_lse_storing_
=
false
;
}
// std::cout << "batch_count_: " << batch_count_ << std::endl;
}
void
Print
()
const
...
...
@@ -803,7 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_rescale_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
View file @
60f1f93a
...
...
@@ -1266,6 +1266,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
MRaw
,
const
index_t
NRaw
,
const
index_t
block_idx_n
)
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
...
...
@@ -1528,7 +1531,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -1562,9 +1565,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
decltype
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockI
D
I1
,
// NBlockI
d
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -1577,8 +1580,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx_n
,
// MBlockId
0
,
// NBlockId
make_multi_index
(
0
,
// MBlockId
block_work_idx_n
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
...
...
@@ -1966,35 +1969,143 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// save z to global
if
(
p_z_grid
)
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
));
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// decltype(z_tenor_buffer),
// true,
// decltype(n0),
// decltype(i)>(s_slash_p_thread_buf,
// ph,
// global_elem_id + id_step
// * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
s_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
s_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
ignore
=
z_grid_buf
;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
);
}
block_sync_lds
();
// wait for gemm1 LDS read
...
...
@@ -2178,7 +2289,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm1
::
c_block_slice_copy_step
);
// step M
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
1
,
0
,
0
,
0
));
y_threadwise_copy
.
MoveSrcSliceWindow
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
60f1f93a
...
...
@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
ushort
p_dropout_in_16bits
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
MRaw
,
const
index_t
NRaw
,
const
index_t
block_idx_m
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -857,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -888,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
...
...
@@ -917,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
block_sync_lds
();
}
do
{
auto
n_block_data_idx_on_grid
=
...
...
@@ -986,6 +989,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
// if(get_thread_global_1d_id()==0){
// printf("m_global is %d \n", m_global);
// printf("n_global is %d \n", n_global);
//}
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
{
acc_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
@@ -1012,31 +1021,90 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
// if(get_thread_global_1d_id()==1){
// printf("at 1 m_global is %d \n", m_global);
// printf("at 1 n_global is %d \n", n_global);
// printf("at 1 global_elem_id is %d \n", global_elem_id);
// }
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
if
(
p_z_grid
)
{
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
false
,
decltype
(
n0
),
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
false
>(
acc_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
(
n0
.
value
),
0
,
0
,
0
,
0
,
0
,
0
));
z_grid_buf
);
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
// decltype(z_tenor_buffer),
// false,
// decltype(n0),
// decltype(i)>(
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
@@ -1046,7 +1114,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
);
acc_thread_buf
,
ph
,
global_elem_id
);
}
}
...
...
include/ck/utility/philox_rand.hpp
View file @
60f1f93a
...
...
@@ -84,6 +84,17 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
out
[
0
]
=
static_cast
<
ushort
>
(
tmp_ph
.
x
);
out
[
1
]
=
static_cast
<
ushort
>
(
tmp_ph
.
y
);
out
[
2
]
=
static_cast
<
ushort
>
(
tmp_ph
.
z
);
out
[
3
]
=
static_cast
<
ushort
>
(
tmp_ph
.
w
);
}
private:
struct
ull2
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment