Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
610e0ab0
Commit
610e0ab0
authored
Jun 14, 2023
by
guangzlu
Browse files
updated dropout for bwd pt4 & pt5
parent
8edf2a72
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
4877 additions
and
199 deletions
+4877
-199
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+1
-0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v5.cpp
+1423
-0
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+154
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+47
-12
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
+40
-15
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
...e_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
+1156
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+225
-85
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
+199
-87
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
..._batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
+1621
-0
include/ck/utility/philox_rand.hpp
include/ck/utility/philox_rand.hpp
+11
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
610e0ab0
...
@@ -14,6 +14,7 @@ add_example_executable(example_batched_multihead_attention_backward_v2 batched_m
...
@@ -14,6 +14,7 @@ add_example_executable(example_batched_multihead_attention_backward_v2 batched_m
add_example_executable
(
example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp
)
add_example_executable
(
example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp
)
add_example_executable
(
example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp
)
add_example_executable
(
example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp
)
add_example_executable
(
example_batched_multihead_attention_train batched_multihead_attention_train.cpp
)
add_example_executable
(
example_batched_multihead_attention_train batched_multihead_attention_train.cpp
)
add_example_executable
(
example_batched_multihead_attention_train_v5 batched_multihead_attention_train_v5.cpp
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
add_custom_target
(
example_gemm_scale_softmax_gemm
)
add_dependencies
(
example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
0 → 100644
View file @
610e0ab0
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
610e0ab0
...
@@ -122,6 +122,160 @@ struct BlockwiseDropout
...
@@ -122,6 +122,160 @@ struct BlockwiseDropout
});
});
}
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwd
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
index_t
MRaw
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutAttnBwdSaveZ
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropoutWithZ
(
CThreadBuffer
&
in_thread_buf
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
>
__host__
__device__
void
GenerateZMatrixAttnFwd
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
// }
//}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
z_thread_buf
(
offset
)
=
tmp
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
ushort
p_dropout_16bits
;
ushort
p_dropout_16bits
;
DataType
p_dropout_rescale
;
DataType
p_dropout_rescale
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
610e0ab0
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt
6
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt
4
.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -39,7 +39,7 @@ template <typename GridwiseGemm,
...
@@ -39,7 +39,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
...
@@ -71,7 +71,7 @@ __global__ void
...
@@ -71,7 +71,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...
@@ -85,7 +85,9 @@ __global__ void
...
@@ -85,7 +85,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -144,6 +146,9 @@ __global__ void
...
@@ -144,6 +146,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -176,6 +181,9 @@ __global__ void
...
@@ -176,6 +181,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
@@ -278,6 +286,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -278,6 +286,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
;
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -808,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -808,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n_
);
// Print();
// Print();
}
}
...
@@ -869,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -869,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -933,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -933,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
LSEGridDesc_M
,
...
@@ -967,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -967,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
...
@@ -979,11 +997,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -979,11 +997,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
c0_matrix_mask_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
#if 1
// if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
// {
// ave_time = launch_kernel(integral_constant<bool, true>{});
// }
// else
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
#endif
return
ave_time
;
return
ave_time
;
}
}
...
@@ -1003,6 +1034,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1003,6 +1034,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
#if 0
arg.Print();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
{
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
View file @
610e0ab0
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt
7
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt
5
.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
...
@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
...
@@ -70,8 +70,8 @@ __global__ void
...
@@ -70,8 +70,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -84,7 +84,9 @@ __global__ void
...
@@ -84,7 +84,9 @@ __global__ void
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -134,7 +136,7 @@ __global__ void
...
@@ -134,7 +136,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -143,6 +145,9 @@ __global__ void
...
@@ -143,6 +145,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -166,7 +171,7 @@ __global__ void
...
@@ -166,7 +171,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -175,6 +180,9 @@ __global__ void
...
@@ -175,6 +180,9 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
@@ -284,6 +292,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -284,6 +292,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
;
using
DeviceOp
=
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -821,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -821,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
(
z_grid_desc_m_n_
);
// Print();
// Print();
}
}
...
@@ -882,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -882,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -950,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -950,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_
M4_M5_N3
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
LSEGridDesc_M
,
DeviceOp
::
LSEGridDesc_M
,
...
@@ -984,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -984,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
...
@@ -996,11 +1014,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -996,11 +1014,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
c0_matrix_mask_
,
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
// to concern Gemm0's loop
#if 1
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
...
@@ -1009,7 +1030,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1009,7 +1030,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
#endif
return
ave_time
;
return
ave_time
;
}
}
...
@@ -1029,6 +1050,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1029,6 +1050,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
#if 0
arg.Print();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
{
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
0 → 100644
View file @
610e0ab0
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
610e0ab0
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
View file @
610e0ab0
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
0 → 100644
View file @
610e0ab0
This diff is collapsed.
Click to expand it.
include/ck/utility/philox_rand.hpp
View file @
610e0ab0
...
@@ -84,6 +84,17 @@ class philox
...
@@ -84,6 +84,17 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
}
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
out
[
0
]
=
static_cast
<
ushort
>
(
tmp_ph
.
x
);
out
[
1
]
=
static_cast
<
ushort
>
(
tmp_ph
.
y
);
out
[
2
]
=
static_cast
<
ushort
>
(
tmp_ph
.
z
);
out
[
3
]
=
static_cast
<
ushort
>
(
tmp_ph
.
w
);
}
private:
private:
struct
ull2
struct
ull2
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment