Commit cb2d4dbb authored by ltqin

Merge branch 'attn-bwd-dropout' into attn-fwd-train-dropout

parents 989e3d10 0e7aeef5
@@ -3,12 +3,14 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
...
@@ -16,12 +16,15 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
-    template <typename CThreadBuffer>
+    template <typename CThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
     {
         auto execute_dropout = [&](bool keep, DataType val) {
-            return keep ? val * p_dropout_rescale : float(0);
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
         };
         constexpr int tmp_size = MRepeat * KRepeat;
@@ -47,6 +50,42 @@
         });
     }
+    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void
+    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+
+        int philox_calls = tmp_size / 8;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            ph.get_random_8x16((tmp + i * 8));
+        }
+
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                z_thread_buf(offset) = tmp[tmp_index];
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
+
     ushort p_dropout_16bits;
     DataType p_dropout_rescale;
 };
...
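The using_sign_bit path above encodes the keep/drop decision in the sign of the stored value instead of zeroing dropped elements, and the new overload additionally records the raw 16-bit philox sample in z_thread_buf so a backward kernel can replay the mask. The sign trick works because the values being dropped here are softmax probabilities, hence nonnegative, leaving the sign bit free. A minimal host-side sketch of the encode/decode round trip (illustrative only, not CK code; recover is a hypothetical backward-side helper):

    #include <cmath>
    #include <cstdio>

    // Forward: keep the magnitude, park the keep/drop mask in the sign bit.
    float apply_dropout_sign_bit(bool keep, float val) { return keep ? val : -val; }

    // Backward: a negative sign means "dropped"; kept values are rescaled by 1/(1-p).
    float recover(float stored, float rescale)
    {
        return std::signbit(stored) ? 0.0f : stored * rescale;
    }

    int main()
    {
        const float p = 0.2f, rescale = 1.0f / (1.0f - p), v = 0.7f;
        std::printf("kept:    %f\n", recover(apply_dropout_sign_bit(true, v), rescale));  // 0.875
        std::printf("dropped: %f\n", recover(apply_dropout_sign_bit(false, v), rescale)); // 0.000
    }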
@@ -50,7 +50,8 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          bool TransposeC = false>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -185,6 +186,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                       "wrong!");
     }
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -211,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
@@ -303,6 +334,58 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
     static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
+    __host__ __device__ static constexpr auto MakeCThreadTileIterator()
+    {
+        constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
+
+        return SpaceFillingCurve<
+            decltype(c_thread_lengths),
+            typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+            typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+            false>{}; // SnakeCurved
+    }
+
+    __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
+    {
+        if constexpr(TransposeC)
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+        else
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+    }
+
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
@@ -905,6 +988,58 @@ struct BlockwiseGemmXdlops_v2
     static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
     static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
+    __host__ __device__ static constexpr auto MakeCThreadTileIterator()
+    {
+        constexpr auto c_thread_lengths = conditional_expr<TransposeC>(
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(),
+            GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths());
+
+        return SpaceFillingCurve<
+            decltype(c_thread_lengths),
+            typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+            typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+            false>{}; // SnakeCurved
+    }
+
+    __host__ __device__ static constexpr auto MakeCThreadIndexAdaptor8DTo2D()
+    {
+        if constexpr(TransposeC)
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto n3 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n4 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2, n3, n4))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+        else
+        {
+            constexpr auto c_thread_desc = GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+            constexpr auto m0 = c_thread_desc.GetLength(Number<0>{});
+            constexpr auto n0 = c_thread_desc.GetLength(Number<1>{});
+            constexpr auto m1 = c_thread_desc.GetLength(Number<2>{});
+            constexpr auto n1 = c_thread_desc.GetLength(Number<3>{});
+            constexpr auto m2 = c_thread_desc.GetLength(Number<4>{});
+            constexpr auto m3 = c_thread_desc.GetLength(Number<5>{});
+            constexpr auto m4 = c_thread_desc.GetLength(Number<6>{});
+            constexpr auto n2 = c_thread_desc.GetLength(Number<7>{});
+
+            constexpr auto thread_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(m0, m1, m2, m3, m4)),
+                           make_unmerge_transform(make_tuple(n0, n1, n2))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+
+            return thread_idx_to_m_n_adaptor;
+        }
+    }
+
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
...
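MakeCThreadIndexAdaptor8DTo2D collapses the eight per-thread sub-dimensions of the XDL output tile back into flat (m, n) coordinates; the unmerge transforms are plain mixed-radix decompositions. For the non-transposed layout, with $M_i$, $N_i$ the lengths read off the thread descriptor, the relation is

$$m = m_0 M_1 M_2 M_3 M_4 + m_1 M_2 M_3 M_4 + m_2 M_3 M_4 + m_3 M_4 + m_4, \qquad n = n_0 N_1 N_2 + n_1 N_2 + n_2,$$

and the TransposeC variant swaps which side carries the finer split (three M components, five N components). The Sequence arguments record where each component sits in the 8D index: M components at positions 0, 2, 4, 5, 6 and N components at 1, 3, 7 in the default layout.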
@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
         });
     }
+    template <typename CThreadBuffer, typename LSEBuffer>
+    __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
+                                                 const LSEBuffer& lse_thread_buf)
+    {
+        // calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
+        // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+        //    = exp(Si) / exp(log(sum(exp() + ...)))
+        //    = exp(Si - log(sum(exp() + ...)))
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                                            ? 0
+                                            : math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
+            });
+        });
+    }
+
     BufferType max_value_buf;
     BufferType sum_value_buf;
 };
...
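RunWithPreCalcStats re-materializes softmax rows from a saved log-sum-exp statistic instead of redoing the row max/sum reductions, which is the standard flash-attention trick for recomputing the probability matrix P from the score matrix S and one scalar per row. The identity, expanding the comment in the code:

$$\mathrm{softmax}(S)_i = \frac{e^{S_i}}{\sum_j e^{S_j}} = e^{S_i - \mathrm{LSE}}, \qquad \mathrm{LSE} = \log \sum_j e^{S_j}.$$

In practice the statistic is computed in the numerically stable form $\mathrm{LSE} = \max_j S_j + \log \sum_j e^{S_j - \max_j S_j}$, which is the same quantity without overflow.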
@@ -84,7 +84,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 {
     static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
     static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
...
@@ -88,7 +88,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
 {
     struct ProblemDesc
     {
...
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -47,7 +47,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_batched_multiheadattention_forward_xdl_cshuffle(
             const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const FloatAB* __restrict__ p_b1_grid,
@@ -205,25 +205,25 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceBatchedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceBatchedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       NumDimO,
                                                       ADataType,
                                                       BDataType,
                                                       B1DataType,
                                                       CDataType,
                                                       LSEDataType,
                                                       Acc0BiasDataType,
                                                       Acc1BiasDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       AccElementwiseOperation,
                                                       B1ElementwiseOperation,
                                                       CElementwiseOperation,
                                                       MaskingSpec>
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -244,7 +244,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -382,7 +382,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
-            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+            const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
@@ -958,7 +958,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
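The launch_kernel lambda receives has_main_k_block_loop_ and is_dropout_ as compile-time constants, so each flag combination selects a distinct kernel instantiation instead of branching inside the kernel. A stripped-down sketch of that dispatch pattern in generic C++ (kernel_stub and dispatch are illustrative stand-ins, not the CK launcher):

    #include <type_traits>
    #include <cstdio>

    template <bool HasMainLoop, bool IsDropout>
    void kernel_stub() // stand-in for one kernel instantiation per flag pair
    {
        std::printf("main_loop=%d dropout=%d\n", HasMainLoop, IsDropout);
    }

    template <typename F>
    void dispatch(bool has_main_loop, bool is_dropout, F f)
    {
        // Lift each runtime bool into std::true_type / std::false_type.
        auto pick_dropout = [&](auto hml) {
            if(is_dropout) f(hml, std::true_type{});
            else           f(hml, std::false_type{});
        };
        if(has_main_loop) pick_dropout(std::true_type{});
        else              pick_dropout(std::false_type{});
    }

    int main()
    {
        dispatch(true, false,
                 [](auto hml, auto drp) { kernel_stub<hml.value, drp.value>(); });
    }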
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -37,7 +37,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+        kernel_grouped_multiheadattention_forward_xdl_cshuffle(
             const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
             const index_t group_count,
             const AElementwiseOperation a_element_op,
@@ -197,25 +197,25 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
                                                       NumDimO,
                                                       ADataType,
                                                       BDataType,
                                                       B1DataType,
                                                       CDataType,
                                                       LSEDataType,
                                                       Acc0BiasDataType,
                                                       Acc1BiasDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       AccElementwiseOperation,
                                                       B1ElementwiseOperation,
                                                       CElementwiseOperation,
                                                       MaskingSpec>
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");
@@ -236,25 +236,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
-    using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                           NumDimM,
                                                                           NumDimN,
                                                                           NumDimK,
                                                                           NumDimO,
                                                                           ADataType,
                                                                           BDataType,
                                                                           B1DataType,
                                                                           CDataType,
                                                                           LSEDataType,
                                                                           Acc0BiasDataType,
                                                                           Acc1BiasDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           AccElementwiseOperation,
                                                                           B1ElementwiseOperation,
                                                                           CElementwiseOperation,
                                                                           MaskingSpec>::ProblemDesc;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -392,7 +392,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
         ADataType, // TODO: distinguish A/B datatype
         GemmAccDataType,
         CShuffleDataType,
@@ -705,16 +705,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
             const auto kernel =
-                kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
                                                                  GemmAccDataType,
                                                                  GroupKernelArg,
                                                                  AElementwiseOperation,
                                                                  BElementwiseOperation,
                                                                  AccElementwiseOperation,
                                                                  B1ElementwiseOperation,
                                                                  CElementwiseOperation,
                                                                  has_main_k_block_loop_,
                                                                  is_dropout_>;
             return launch_and_time_kernel(
                 stream_config,
@@ -969,7 +969,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
...
@@ -95,6 +95,8 @@ struct Scale
         y = scale_ * x;
     };
+    __host__ __device__ void Append(float scale) { scale_ = scale_ * scale; }
+
     float scale_;
 };
...
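Append composes an extra factor into the functor multiplicatively, so a caller can fold, for example, the 1/sqrt(d_k) attention scale and the dropout rescale 1/(1-p) into a single per-element multiply. A hedged usage sketch (assuming Scale lives in ck::tensor_operation::element_wise and is constructible from one float, as its single scale_ member suggests):

    ck::tensor_operation::element_wise::Scale scale{0.125f}; // 1/sqrt(64)
    scale.Append(1.0f / (1.0f - 0.2f));                      // fold in dropout rescale
    float y = 0;
    scale(y, 2.0f); // y = 2.0f * 0.125f * 1.25f = 0.3125f, one multiply per element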
@@ -83,7 +83,7 @@ template <typename FloatAB,
           bool PadN,
           bool MaskOutUpperTriangle,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
+struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
...
@@ -143,6 +143,16 @@ struct DynamicBuffer
         }
     }
+    __host__ __device__ void Clear()
+    {
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Lds,
+                      "wrong! only local data share is supported");
+
+        for(index_t i = get_thread_local_1d_id(); i < element_space_size_; i += get_block_size())
+        {
+            Set(i, true, T{0});
+        }
+    }
+
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                          typename scalar_type<remove_cvref_t<T>>::type>::value,
@@ -302,7 +312,9 @@ struct DynamicBuffer
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                       "wrong! X should contain multiple T");
-        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
+        static_assert(GetAddressSpace() == AddressSpaceEnum::Global ||
+                          GetAddressSpace() == AddressSpaceEnum::Lds,
+                      "only support global mem or local data share");
 #if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
         bool constexpr use_amd_buffer_addressing =
@@ -319,7 +331,7 @@ struct DynamicBuffer
         bool constexpr use_amd_buffer_addressing = false;
 #endif
-        if constexpr(use_amd_buffer_addressing)
+        if constexpr(use_amd_buffer_addressing && GetAddressSpace() == AddressSpaceEnum::Global)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
...
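Clear() wipes an LDS-backed buffer cooperatively: each thread starts at its own lane id and strides by the block size, so every element is written exactly once per block (callers still need a barrier before consuming the zeroed memory). The atomic-add change follows the same split as the relaxed static_assert: AMD buffer-addressing intrinsics exist only for global memory, so the added GetAddressSpace() == AddressSpaceEnum::Global guard forces LDS atomics down the generic path. The access pattern, as a bare HIP/CUDA-style sketch (generic names, not CK's wrappers):

    __device__ void clear_lds(float* lds, int n)
    {
        // Thread t zeroes elements t, t + blockDim.x, t + 2*blockDim.x, ...
        for(int i = threadIdx.x; i < n; i += blockDim.x)
            lds[i] = 0.0f;
        __syncthreads(); // make the zeros visible to every thread in the block
    }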