Commit d9579dc8 authored by fsx950223

merge updates

parents 98ccee74 36ca02f3
...@@ -5,14 +5,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 ...@@ -5,14 +5,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp) add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp) add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp) add_example_executable(example_batched_multihead_attention_backward_pt1 batched_multihead_attention_backward_pt1.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp) add_example_executable(example_batched_multihead_attention_backward_pt2 batched_multihead_attention_backward_pt2.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp) add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp) add_example_executable(example_grouped_multihead_attention_backward_fp16 grouped_multihead_attention_backward_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_train_fp16 batched_multihead_attention_train_fp16.cpp)
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 1
#define USING_HD32 0 #define USING_HD32 0
#include <iostream> #include <iostream>
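The flipped USING_MASK default in this hunk switches the example to its causal (lower-triangle) masking variant, while USING_HD32 selects the head-dimension-32 problem sizes used further down. For reference, causal masking simply excludes attention scores above the diagonal from the softmax; a minimal host-side sketch of that step (illustrative only, not the kernel's masking specialization):

    // Causal mask on host: scores S[m][n] with n > m are pushed to -inf so the
    // softmax only sees the lower triangle. S is row-major with shape [M, N].
    #include <limits>
    #include <vector>

    void apply_lower_triangle_mask(std::vector<float>& S, int M, int N)
    {
        const float neg_inf = -std::numeric_limits<float>::infinity();
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                if(n > m)
                    S[m * N + n] = neg_inf;
    }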
...@@ -49,9 +49,10 @@ Kernel outputs: ...@@ -49,9 +49,10 @@ Kernel outputs:
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
 using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
...@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -59,7 +60,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using DataType = BF16;
using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -101,6 +103,7 @@ using DeviceGemmInstance = ...@@ -101,6 +103,7 @@ using DeviceGemmInstance =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -169,6 +172,7 @@ using DeviceGemmInstance = ...@@ -169,6 +172,7 @@ using DeviceGemmInstance =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -340,16 +344,21 @@ int run(int argc, char* argv[]) ...@@ -340,16 +344,21 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-ck::index_t M = 512; // 512
-ck::index_t N = 512; // 512
-ck::index_t K = 64;
-ck::index_t O = 64;
-ck::index_t G0 = 4; // 54
-ck::index_t G1 = 6; // 16
+ck::index_t M = 1536; // 512
+ck::index_t N = 1536; // 512
+#if USING_HD32
+ck::index_t K = 32; // K/O<=32
+ck::index_t O = 32;
+#else
+ck::index_t K = 64; // 32<K/O<=64
+ck::index_t O = 64;
+#endif
+ck::index_t G0 = 1; // 54
+ck::index_t G1 = 1; // 16
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
bool input_permute = true; // false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.2; float p_drop = 0.2;
...@@ -386,6 +395,8 @@ int run(int argc, char* argv[]) ...@@ -386,6 +395,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
p_drop = std::stof(argv[13]);
} }
else else
{ {
...@@ -398,6 +409,22 @@ int run(int argc, char* argv[]) ...@@ -398,6 +409,22 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -747,9 +774,12 @@ int run(int argc, char* argv[]) ...@@ -747,9 +774,12 @@ int run(int argc, char* argv[])
{ {
auto idx_gmo = idx_gmn; auto idx_gmo = idx_gmn;
idx_gmo[2] = o; idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo); ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
} }
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
#if PRINT_HOST #if PRINT_HOST
{ {
......
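The host-side check in the hunk above is the usual softmax-attention backward step, now carried out in AccDataType (fp32) and only converted back to DataType at the end, which matters once DataType is bf16. Written out, with Y = P V and dY the incoming gradient:

    t_m = \sum_o dY_{m,o} \, Y_{m,o}
    dS_{m,n} = P_{m,n} \, ( dP_{m,n} - t_m ),   where   dP = dY \cdot V^T

In the code, ygrad_dot_y is t_m, pgrad_g_m_n is dP, and p_g_m_n is the softmax output P.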
...@@ -50,9 +50,10 @@ Kernel outputs: ...@@ -50,9 +50,10 @@ Kernel outputs:
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
 using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
...@@ -387,6 +388,8 @@ int run(int argc, char* argv[]) ...@@ -387,6 +388,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
p_drop = std::stof(argv[13]);
} }
else else
{ {
...@@ -399,6 +402,22 @@ int run(int argc, char* argv[]) ...@@ -399,6 +402,22 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -748,9 +767,12 @@ int run(int argc, char* argv[]) ...@@ -748,9 +767,12 @@ int run(int argc, char* argv[])
{ {
auto idx_gmo = idx_gmn; auto idx_gmo = idx_gmn;
idx_gmo[2] = o; idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo); ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
} }
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
#if PRINT_HOST #if PRINT_HOST
{ {
......
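The same host check is repeated in this second example file. For completeness, in this notation the remaining gradients of scaled-dot-product attention (dropout rescaling omitted) are the standard identities

    dV = P^T dY,        dQ = \alpha \, dS \, K,        dK = \alpha \, dS^T Q,

with S = \alpha Q K^T and P = softmax(S); the dS appearing here is exactly the quantity computed element-wise in the loop above.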
...@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
using U16 = unsigned short; using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType        = BF16;
-using B0DataType       = BF16;
-using B1DataType       = BF16;
+using DataType         = BF16;
+using GemmDataType     = BF16;
+using ADataType        = DataType;
+using B0DataType       = DataType;
+using B1DataType       = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
-using CDataType        = BF16;
+using CDataType        = DataType;
using ZDataType = U16; using ZDataType = U16;
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
...@@ -81,6 +84,7 @@ using DeviceGemmInstance = ...@@ -81,6 +84,7 @@ using DeviceGemmInstance =
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -99,7 +103,7 @@ using DeviceGemmInstance = ...@@ -99,7 +103,7 @@ using DeviceGemmInstance =
TensorSpecC, TensorSpecC,
1, 1,
256, 256,
256, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
64, // Gemm1NPerBlock 64, // Gemm1NPerBlock
...@@ -109,7 +113,7 @@ using DeviceGemmInstance = ...@@ -109,7 +113,7 @@ using DeviceGemmInstance =
2, // B1K1 2, // B1K1
32, // MPerXDL 32, // MPerXDL
32, // NPerXDL 32, // NPerXDL
2, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
...@@ -139,7 +143,7 @@ using DeviceGemmInstance = ...@@ -139,7 +143,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: bf16 in, fp32 out // Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType, B0DataType,
AccDataType, AccDataType,
...@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp, B0ElementOp,
Acc0ElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, bf16 out // Ref Softmax: AccDataType in, DataType out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: bf16 in, bf16 out // Ref Gemm1: DataType in, DataType out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType, B1DataType,
CDataType, CDataType,
......
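A quick sanity check on the retuned tile in the previous hunk (MPerBlock 256 -> 128 together with MXdlPerWave 2 -> 1): assuming the usual CK xdlops decomposition BlockSize = MWaves * NWaves * wave_size with MWaves = MPerBlock / (MXdlPerWave * MPerXDL) and NWaves = NPerBlock / (NXdlPerWave * NPerXDL), both the old and the new settings give a 4 x 1 wave grid, so the change shrinks the per-block M footprint without altering the wave layout. A sketch of that arithmetic (the decomposition is an assumption about the kernel, not quoted from it):

    // Wave decomposition implied by the new tile parameters shown in the diff.
    constexpr int BlockSize   = 256;
    constexpr int MPerBlock   = 128; // was 256
    constexpr int NPerBlock   = 128;
    constexpr int MPerXDL     = 32;
    constexpr int NPerXDL     = 32;
    constexpr int MXdlPerWave = 1;   // was 2
    constexpr int NXdlPerWave = 4;
    constexpr int WaveSize    = 64;

    constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXDL); // 4 (also 4 before the change)
    constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXDL); // 1
    static_assert(MWaves * NWaves * WaveSize == BlockSize, "tile does not fill the block");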
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
#include "run_batched_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
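Since the new batched forward example above relies on run_batched_multihead_attention_forward.inc for the driver, it helps to keep the fused computation itself in mind: per batch, C = Softmax(alpha * A * B0) * B1. A naive single-batch host reference of that pipeline (row-major buffers, dropout and masking omitted; illustrative only, this is not the CK reference path the example actually calls):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // C[m][o] = sum_n softmax_n(alpha * sum_k A[m][k] * B0[k][n]) * B1[n][o]
    void attention_forward_reference(const std::vector<float>& A,  // [M, K]
                                     const std::vector<float>& B0, // [K, N]
                                     const std::vector<float>& B1, // [N, O]
                                     std::vector<float>& C,        // [M, O]
                                     int M, int N, int K, int O, float alpha)
    {
        std::vector<float> s(N);
        for(int m = 0; m < M; ++m)
        {
            float row_max = -INFINITY;
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += A[m * K + k] * B0[k * N + n]; // Gemm0
                s[n]    = alpha * acc;                   // Acc0ElementOp = Scale
                row_max = std::max(row_max, s[n]);
            }
            float sum = 0.f;
            for(int n = 0; n < N; ++n)
            {
                s[n] = std::exp(s[n] - row_max); // numerically stable softmax
                sum += s[n];
            }
            for(int o = 0; o < O; ++o)
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += (s[n] / sum) * B1[n * O + o]; // Gemm1
                C[m * O + o] = acc;
            }
        }
    }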
...@@ -59,9 +59,10 @@ Kernel outputs: ...@@ -59,9 +59,10 @@ Kernel outputs:
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
 using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
...@@ -69,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -69,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using DataType = BF16;
using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -108,6 +110,7 @@ using DeviceGemmInstanceFWD = ...@@ -108,6 +110,7 @@ using DeviceGemmInstanceFWD =
DataType, DataType,
DataType, DataType,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -180,6 +183,7 @@ using DeviceGemmInstanceBWD = ...@@ -180,6 +183,7 @@ using DeviceGemmInstanceBWD =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -248,6 +252,7 @@ using DeviceGemmInstanceBWD = ...@@ -248,6 +252,7 @@ using DeviceGemmInstanceBWD =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -419,8 +424,8 @@ int run(int argc, char* argv[]) ...@@ -419,8 +424,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 200; // 512 ck::index_t M = 129; // 512
ck::index_t N = 200; // 512 ck::index_t N = 129; // 512
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 64; ck::index_t O = 64;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4; // 54
...@@ -428,8 +433,8 @@ int run(int argc, char* argv[]) ...@@ -428,8 +433,8 @@ int run(int argc, char* argv[])
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
bool input_permute = false; bool input_permute = true;
bool output_permute = false; bool output_permute = true;
float p_drop = 0.0; float p_drop = 0.0;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
...@@ -465,6 +470,8 @@ int run(int argc, char* argv[]) ...@@ -465,6 +470,8 @@ int run(int argc, char* argv[])
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[12]);
p_drop = std::stof(argv[13]);
} }
else else
{ {
...@@ -477,6 +484,22 @@ int run(int argc, char* argv[]) ...@@ -477,6 +484,22 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -959,9 +982,12 @@ int run(int argc, char* argv[]) ...@@ -959,9 +982,12 @@ int run(int argc, char* argv[])
{ {
auto idx_gmo = idx_gmn; auto idx_gmo = idx_gmn;
idx_gmo[2] = o; idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo); ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
} }
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y); self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
#if PRINT_HOST #if PRINT_HOST
{ {
...@@ -1058,7 +1084,7 @@ int run(int argc, char* argv[]) ...@@ -1058,7 +1084,7 @@ int run(int argc, char* argv[])
double atol = 1e-3; double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01 // when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<DataType, ck::bhalf_t>) if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{ {
rtol = 1e-2; rtol = 1e-2;
atol = 1e-2; atol = 1e-2;
......
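The relaxed tolerance in the hunk above (rtol = atol = 1e-2 once either DataType or GemmDataType is bf16, versus 1e-3 otherwise) reflects bf16's 8-bit mantissa. The usual combined check such thresholds feed into has the form below; the exact formula inside CK's check_err may differ, so treat this as the generic version:

    #include <cmath>

    // Generic mixed relative/absolute tolerance test.
    bool close_enough(double out, double ref, double rtol, double atol)
    {
        return std::abs(out - ref) <= atol + rtol * std::abs(ref);
    }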
...@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -32,18 +32,21 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
using U16 = unsigned short; using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType        = BF16;
-using B0DataType       = BF16;
-using B1DataType       = BF16;
+using DataType         = F16;
+using GemmDataType     = F16;
+using ADataType        = DataType;
+using B0DataType       = DataType;
+using B1DataType       = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
-using CDataType        = BF16;
+using CDataType        = DataType;
using ZDataType = U16; using ZDataType = U16;
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
...@@ -81,6 +84,7 @@ using DeviceGemmInstance = ...@@ -81,6 +84,7 @@ using DeviceGemmInstance =
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -102,8 +106,8 @@ using DeviceGemmInstance = ...@@ -102,8 +106,8 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 64, // Gemm1NPerBlock
64, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
...@@ -111,7 +115,7 @@ using DeviceGemmInstance = ...@@ -111,7 +115,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
4, // Gemm1NXdlPerWave 2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -139,7 +143,7 @@ using DeviceGemmInstance = ...@@ -139,7 +143,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: bf16 in, fp32 out // Ref Gemm0: DataType in, AccDataType out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType, B0DataType,
AccDataType, AccDataType,
...@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -148,11 +152,11 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B0ElementOp, B0ElementOp,
Acc0ElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, bf16 out // Ref Softmax: AccDataType in, DataType out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: bf16 in, bf16 out // Ref Gemm1: DataType in, DataType out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType, B1DataType,
CDataType, CDataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
#include "run_grouped_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
...@@ -97,7 +97,7 @@ int run(int argc, char* argv[]) ...@@ -97,7 +97,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
...@@ -360,8 +360,7 @@ int run(int argc, char* argv[]) ...@@ -360,8 +360,7 @@ int run(int argc, char* argv[])
double atol = 1e-3; double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01 // when BF16 is taken, set absolute error and relative error to 0.01
-    if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
-       std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
+    if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{ {
rtol = 1e-2; rtol = 1e-2;
atol = 1e-2; atol = 1e-2;
......
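The stride fix above only changes which of the two Z-layout stride vectors is selected (keying it off input_permute instead of output_permute). The flat offset of element (g0, g1, m, n) under each stride set is easy to sanity-check:

    #include <cstdint>

    // Z stored as [G0, M, G1, N] ("permuted"): strides {M*G1*N, N, G1*N, 1}.
    std::int64_t z_offset_permuted(int g0, int g1, int m, int n, int G1, int M, int N)
    {
        return std::int64_t(g0) * M * G1 * N + std::int64_t(g1) * N + std::int64_t(m) * G1 * N + n;
    }

    // Z stored as [G0, G1, M, N]: strides {G1*M*N, M*N, N, 1}.
    std::int64_t z_offset_packed(int g0, int g1, int m, int n, int G1, int M, int N)
    {
        return std::int64_t(g0) * G1 * M * N + std::int64_t(g1) * M * N + std::int64_t(m) * N + n;
    }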
...@@ -10,8 +10,7 @@ int run(int argc, char* argv[]) ...@@ -10,8 +10,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
float p_drop = 0.1;
float p_drop = 0.2;
float p_dropout = 1 - p_drop; float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
...@@ -84,8 +83,8 @@ int run(int argc, char* argv[]) ...@@ -84,8 +83,8 @@ int run(int argc, char* argv[])
int M = 128 * (rand() % 8 + 1); int M = 128 * (rand() % 8 + 1);
int N = 128 * (rand() % 8 + 1); int N = 128 * (rand() % 8 + 1);
int K = 128; int K = 64;
int O = 128; int O = 64;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1 = rand() % 5 + 1;
...@@ -117,7 +116,7 @@ int run(int argc, char* argv[]) ...@@ -117,7 +116,7 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
...@@ -427,8 +426,7 @@ int run(int argc, char* argv[]) ...@@ -427,8 +426,7 @@ int run(int argc, char* argv[])
double atol = 1e-3; double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01 // when BF16 is taken, set absolute error and relative error to 0.01
-    if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
-       std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
+    if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{ {
rtol = 1e-2; rtol = 1e-2;
atol = 1e-2; atol = 1e-2;
......
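For the dropout parameters this file touches (p_drop, p_dropout, p_dropout_in_16bits, rp_dropout), the bookkeeping is standard inverted dropout: keep an element with probability 1 - p_drop by comparing a uniform 16-bit draw against floor((1 - p_drop) * 65535), and rescale kept values by 1 / (1 - p_drop) so the expectation is unchanged. A host-side sketch under those assumptions (the kernel draws its 16-bit values from philox, and its comparison may be < rather than <=):

    #include <cmath>
    #include <cstdint>

    struct DropoutParams
    {
        float         p_dropout;           // keep probability = 1 - p_drop
        std::uint16_t p_dropout_in_16bits; // keep threshold for a uniform 16-bit draw
        float         rp_dropout;          // 1 / keep probability (rescale factor)
    };

    DropoutParams make_dropout_params(float p_drop)
    {
        DropoutParams d{};
        d.p_dropout           = 1.f - p_drop;
        d.p_dropout_in_16bits = static_cast<std::uint16_t>(std::floor(d.p_dropout * 65535.0));
        d.rp_dropout          = 1.f / d.p_dropout;
        return d;
    }

    // Keep-and-rescale decision for a single element.
    float apply_dropout(float x, std::uint16_t rand_u16, const DropoutParams& d)
    {
        return (rand_u16 <= d.p_dropout_in_16bits) ? x * d.rp_dropout : 0.f;
    }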
...@@ -118,7 +118,7 @@ ...@@ -118,7 +118,7 @@
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif #endif
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 0 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
// experimental feature: in-register sub-dword transpose // experimental feature: in-register sub-dword transpose
......
...@@ -173,6 +173,7 @@ template <index_t NumDimG, ...@@ -173,6 +173,7 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename DataType,
typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
...@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -598,9 +599,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
LSEDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
LSEDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
......
...@@ -158,6 +158,7 @@ template <index_t NumDimG, ...@@ -158,6 +158,7 @@ template <index_t NumDimG,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
...@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -412,6 +413,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
......
...@@ -148,6 +148,7 @@ template <index_t NumDimG, ...@@ -148,6 +148,7 @@ template <index_t NumDimG,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
...@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle ...@@ -423,6 +424,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
......
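The GemmDataType parameter added to the device ops above and threaded into the gridwise kernels below separates the storage element type (DataType / ADataType) from the element type handed to the MFMA math pipeline; in these examples the two coincide (bf16/bf16 or fp16/fp16), but the split leaves room for storing in one precision and computing in another. A generic sketch of that storage/compute split (names here are illustrative, not the CK template parameters):

    #include <cstddef>

    // Data is held in StorageT and converted to ComputeT just before the
    // multiply, accumulating in AccT.
    template <typename StorageT, typename ComputeT, typename AccT>
    AccT dot(const StorageT* a, const StorageT* b, std::size_t n)
    {
        AccT acc = AccT(0);
        for(std::size_t i = 0; i < n; ++i)
            acc += static_cast<AccT>(static_cast<ComputeT>(a[i])) *
                   static_cast<AccT>(static_cast<ComputeT>(b[i]));
        return acc;
    }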
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename DataType, template <typename DataType,
typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatLSE, typename FloatLSE,
...@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -121,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks; constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size; constexpr auto N5 = mfma.group_size;
...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1), decltype(q_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1), decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1), decltype(v_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1), decltype(ygrad_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -506,13 +507,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
BBlockDesc_BK0_N_BK1{}); BBlockDesc_BK0_N_BK1{});
} }
-    static constexpr index_t KPack = math::max(
-        math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+    static constexpr index_t KPack =
+        math::max(math::lcm(AK1, BK1),
+                  MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
DataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -587,7 +589,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc, FloatGemmAcc,
DataType, GemmDataType,
decltype(a_src_thread_desc_k0_m_k1), decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1), decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -610,7 +612,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmMWave = Gemm0MWaves; static constexpr index_t GemmMWave = Gemm0MWaves;
static constexpr index_t GemmNWave = Gemm0NWaves; static constexpr index_t GemmNWave = Gemm0NWaves;
...@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -676,8 +678,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1(); static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy = using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<DataType, ThreadwiseTensorSliceTransfer_v2<GemmDataType,
DataType, GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3), decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3), decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3, BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
...@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -692,7 +694,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
DataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1), decltype(a_thread_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1), decltype(b_thread_desc_k0_n_k1),
...@@ -733,12 +735,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -733,12 +735,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl; static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMLoop = Free1_M / Sum_M; static constexpr index_t GemmMLoop = Free1_M / Sum_M;
static constexpr index_t GemmMPack = static constexpr index_t GemmMPack =
math::max(A_M1, MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::max(A_M1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t B_M3 = GemmMPack; // 8 static constexpr index_t B_M3 = GemmMPack; // 8
static constexpr index_t B_M2 = static constexpr index_t B_M2 =
XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2 XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmMPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3; // 4 static constexpr index_t B_M1 = Sum_M / B_M2 / B_M3; // 4
static constexpr index_t B_M0 = GemmMLoop; // 2 static constexpr index_t B_M0 = GemmMLoop; // 2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2() __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
{ {
...@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -875,7 +877,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough> template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
ElementwiseOp, ElementwiseOp,
...@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -968,8 +970,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto b_thread_desc_m0_o_m1 = MakeBThreadDesc_M0_O_M1(); static constexpr auto b_thread_desc_m0_o_m1 = MakeBThreadDesc_M0_O_M1();
using BBlockwiseCopy = using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<DataType, ThreadwiseTensorSliceTransfer_v2<GemmDataType,
DataType, GemmDataType,
decltype(b_block_desc_o0_o1_o2_m0_m1_m2_m3), decltype(b_block_desc_o0_o1_o2_m0_m1_m2_m3),
decltype(b_thread_desc_o0_o1_o2_m0_m1_m2_m3), decltype(b_thread_desc_o0_o1_o2_m0_m1_m2_m3),
BThreadSlice_O0_O1_O2_M0_M1_M2_M3, BThreadSlice_O0_O1_O2_M0_M1_M2_M3,
...@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -985,7 +987,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
DataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_m0_n_m1), decltype(a_block_desc_m0_n_m1),
decltype(b_thread_desc_m0_o_m1), decltype(b_thread_desc_m0_o_m1),
...@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1001,7 +1003,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
Gemm2Params_N_O_M::GemmMPack, Gemm2Params_N_O_M::GemmMPack,
true, // TransposeC true, // TransposeC
Gemm2Params_N_O_M::GemmMPack * Gemm2Params_N_O_M::GemmMPack *
XdlopsGemm<DataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{} XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params_N_O_M::GemmMPack, false>{}
.K0PerXdlops, .K0PerXdlops,
Gemm2Params_N_O_M::GemmMPack>; Gemm2Params_N_O_M::GemmMPack>;
...@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1092,7 +1094,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, ""); static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr, using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType, FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O, ThreadSliceLength_M * ThreadSliceLength_O,
true>; true>;
...@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1165,7 +1167,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto p_slash_sgrad_block_desc_m0_n_m1 = static constexpr auto p_slash_sgrad_block_desc_m0_n_m1 =
GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>(); GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{}; static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto q_block_space_size_aligned = static constexpr auto q_block_space_size_aligned =
math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(q_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
...@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1193,7 +1195,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr auto reduction_space_offset = static constexpr auto reduction_space_offset =
(ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) * (ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) *
sizeof(DataType) / sizeof(FloatGemmAcc); sizeof(GemmDataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1206,14 +1208,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
{ {
const index_t k_bytes_end = const index_t k_bytes_end =
(SharedMemTrait::k_block_space_offset + SharedMemTrait::k_block_space_size_aligned) * (SharedMemTrait::k_block_space_offset + SharedMemTrait::k_block_space_size_aligned) *
sizeof(DataType); sizeof(GemmDataType);
const index_t v_bytes_end = const index_t v_bytes_end =
(SharedMemTrait::v_block_space_offset + SharedMemTrait::v_block_space_size_aligned) * (SharedMemTrait::v_block_space_offset + SharedMemTrait::v_block_space_size_aligned) *
sizeof(DataType); sizeof(GemmDataType);
const index_t p_slash_sgrad_bytes_end = const index_t p_slash_sgrad_bytes_end =
(SharedMemTrait::p_slash_sgrad_block_space_offset + (SharedMemTrait::p_slash_sgrad_block_space_offset +
SharedMemTrait::p_slash_sgrad_block_space_size_aligned) * SharedMemTrait::p_slash_sgrad_block_space_size_aligned) *
sizeof(DataType); sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) * SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc); sizeof(FloatGemmAcc);
...@@ -1263,8 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1263,8 +1265,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const float p_drop, const float p_drop,
ck::philox& ph) ck::philox& ph)
{ {
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = const ushort p_dropout_in_16bits =
__builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0)); __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() * const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
...@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1315,19 +1317,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// LDS allocation for Q / K / V / dY // LDS allocation for Q / K / V / dY
auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto q_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::q_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::q_block_space_offset,
GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize()); GemmBlockwiseCopy::q_block_desc_k0_m_k1.GetElementSpaceSize());
auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::k_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize()); GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto v_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::v_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::v_block_space_offset,
GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize()); GemmBlockwiseCopy::v_block_desc_k0_n_k1.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize()); GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1.GetElementSpaceSize());
// Q matrix blockwise copy // Q matrix blockwise copy
...@@ -1394,10 +1396,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>; decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Gemm1: VGPR allocation for A and B // Gemm1: VGPR allocation for A and B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>( auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize()); Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>( auto gemm1_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize()); Gemm1::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
// dQ: transform input and output tensor descriptors // dQ: transform input and output tensor descriptors
...@@ -1589,10 +1591,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// Gemm2: LDS allocation for A and B: be careful of alignment // Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::p_slash_sgrad_block_space_offset,
Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize()); Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>( auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3.GetElementSpaceSize()); Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3.GetElementSpaceSize());
// dV: transform input and output tensor descriptors // dV: transform input and output tensor descriptors
...@@ -1722,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for y // performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, DataType,
DataType, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()), decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
...@@ -1735,8 +1737,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// performs for ygrad // performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, GemmDataType,
DataType, FloatGemmAcc,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o), decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o), decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()), decltype(ygrad_thread_desc_m_o.GetLengths()),
...
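Note: the recurring edit in this file swaps DataType for GemmDataType when sizing LDS and VGPR buffers, separating the element type stored in global memory from the type staged for the XDL GEMMs. A minimal sketch of that split, with names invented only for illustration:

// InputType is what Q/K/V/dY are stored as in global memory; GemmType is what is staged
// in LDS/VGPR and consumed by the MFMA instructions; AccType holds the accumulators.
template <typename InputType, typename GemmType, typename AccType>
struct AttentionTypeSplit
{
    using load_type  = GemmType;  // global InputType is converted on the way into LDS
    using mma_type   = GemmType;  // element type the blockwise GEMMs operate on
    using acc_type   = AccType;   // accumulation precision
    using store_type = InputType; // results converted back for the global write
};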
...@@ -908,7 +908,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(FloatGemmAcc); static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
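Note: SrcScalarPerVector is now derived from sizeof(DataType) so that each thread still moves a 16-byte vector regardless of element width, and the O-direction thread cluster length adjusts to match. A small standalone check of that arithmetic:

#include <cstdint>
#include <cstdio>

template <typename T>
constexpr int scalars_per_16_bytes() { return 16 / static_cast<int>(sizeof(T)); }

int main()
{
    // std::uint16_t stands in for a 2-byte fp16/bf16 element here.
    std::printf("2-byte elements: %d per 16-byte vector\n", scalars_per_16_bytes<std::uint16_t>());
    std::printf("4-byte elements: %d per 16-byte vector\n", scalars_per_16_bytes<float>());
    return 0;
}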
...
...@@ -21,6 +21,7 @@
namespace ck { namespace ck {
template <typename FloatAB, template <typename FloatAB,
typename FloatGemm,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
...@@ -126,7 +127,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma; constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks; constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size; constexpr auto N5 = mfma.group_size;
...@@ -242,10 +243,10 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{ {
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) * SharedMemTrait::b_block_space_size_aligned) *
sizeof(FloatAB); sizeof(FloatGemm);
const index_t gemm1_bytes_end = const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatAB); sizeof(FloatGemm);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) * SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc); sizeof(FloatGemmAcc);
...@@ -273,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
// if(Gemm1N != K) if(Gemm1N != K)
//{ {
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false; return false;
//} }
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{ {
...@@ -495,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatGemm,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -526,7 +527,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatGemm,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -554,12 +555,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// acc1[m][o] += acc[m][n] * B1[n][o] // acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check // sanity check
constexpr index_t KPack = math::max( constexpr index_t KPack =
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::max(math::lcm(AK1, BK1),
MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2< auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
FloatAB, FloatGemm,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -579,11 +581,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<FloatGemm*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize()); a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset, static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...@@ -658,7 +660,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// A1 matrix blockwise copy // A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc, FloatGemmAcc,
FloatAB, FloatGemm,
decltype(acc_thread_desc_k0_m_k1), decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -677,7 +679,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatGemm,
decltype(b1_grid_desc_bk0_n_bk1), decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1), decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -698,12 +700,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>( auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemm>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize()); a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b_block_buf // reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset, static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
...@@ -716,11 +718,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack = constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
FloatAB, FloatGemm,
FloatGemmAcc, FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1), decltype(b1_block_desc_bk0_n_bk1),
...@@ -736,7 +738,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Gemm1KPack, Gemm1KPack,
true, // TransposeC true, // TransposeC
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack * XdlopsGemm<FloatGemm, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_tuple(0, 0, 0, 0)}; // A_origin
...@@ -850,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -881,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
n0, // NRepeat I1, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1004,25 +1006,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
false>( false,
acc_thread_buf, ph, z_tenor_buffer); decltype(n0),
decltype(i)>(
z_thread_copy_vgpr_to_global.Run( acc_thread_buf, ph, z_tenor_buffer);
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), z_thread_copy_vgpr_to_global.Run(
z_tenor_buffer, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
// ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>( blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
acc_thread_buf, ph); acc_thread_buf, ph);
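Note: the rewritten z path above applies dropout and writes the random-state tensor one NRepeat slice at a time, advancing the destination slice window by one repeat per iteration, rewinding it by n0 afterwards, and only then stepping to the next N block. A plain-index analogy of that walk (not the CK descriptor API):

#include <cstdio>

int main()
{
    const int n0 = 4;            // illustrative NRepeat count
    int repeat = 0, block = 0;
    for(int i = 0; i < n0; ++i)
    {
        std::printf("store slice: block %d, repeat %d\n", block, repeat);
        ++repeat;                // MoveDstSliceWindow(+1 repeat)
    }
    repeat -= n0;                // MoveDstSliceWindow(-n0): rewind to repeat 0
    ++block;                     // MoveDstSliceWindow(+1 N block)
    std::printf("next iteration starts at block %d, repeat %d\n", block, repeat);
    return 0;
}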
...@@ -1100,7 +1111,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// workaround compiler issue; see ck/ck.hpp // workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 && (is_same_v<FloatGemm, bhalf_t>)&&MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128) Gemm1NPerBlock == 128)
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...
...@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
return amd_buffer_load_impl<scalar_t, vector_size>( return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
...@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_store_impl<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
...@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>( amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
...@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<scalar_t, vector_size>( amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
...
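Note: the four edits in this file change the invalid-lane address shift from 0x7fffffff to 0x80000000, relying on the buffer instructions' out-of-range behavior (loads return zero, stores and atomics are dropped) once the shifted offset falls outside the range described by the buffer resource. A hedged scalar emulation of that masking idea, with the bounds check made explicit:

#include <cstdint>
#include <cstdio>
#include <vector>

// Invalid lanes get a huge offset shift; any access whose shifted offset exceeds the
// buffer extent returns zero instead of touching memory.
float guarded_load(const std::vector<float>& buf, std::uint32_t offset, bool valid)
{
    const std::uint32_t shift = valid ? 0u : 0x80000000u; // same constant as the diff
    const std::uint64_t index = std::uint64_t(shift) + offset;
    return (index < buf.size()) ? buf[static_cast<std::size_t>(index)] : 0.0f;
}

int main()
{
    std::vector<float> buf{1.0f, 2.0f, 3.0f};
    std::printf("valid: %f, invalid: %f\n", guarded_load(buf, 1, true), guarded_load(buf, 1, false));
    return 0;
}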
...@@ -71,6 +71,141 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return vy.template AsType<double2_t>()[I0]; return vy.template AsType<double2_t>()[I0];
} }
inline __host__ __device__ half2_t add_fp16x2_t(const half2_t& a, const half2_t& b)
{
half2_t rtn;
rtn[0] = a[0] + b[0];
rtn[1] = a[1] + b[1];
return rtn;
}
union U32FP162_ADDR
{
uint32_t* u32_a;
half2_t* fp162_a;
};
union U32FP162
{
uint32_t u32;
half2_t fp162;
};
template <>
__device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
{
U32FP162_ADDR dword_addr;
U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ half2_t atomic_add<half2_t>(half2_t* p_dst, const half2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// half2_t new_ = add_fp16x2_t(*reinterpret_cast<half2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
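Note: the new atomic_add<half2_t> specialization emulates a packed fp16 atomic add with a 32-bit compare-and-swap loop, using a union to view the two halves as one dword. A hedged sketch of the same CAS-loop pattern on a plain float, using only intrinsics that exist in both HIP and CUDA:

// Read the current 32-bit word, compute the updated value, and try to swap it in;
// retry whenever another thread changed the word in between.
__device__ float atomic_add_via_cas(float* addr, float val)
{
    unsigned int* word = reinterpret_cast<unsigned int*>(addr);
    unsigned int  seen = *word;
    unsigned int  assumed;
    do
    {
        assumed         = seen;
        const float sum = __uint_as_float(assumed) + val;
        seen            = atomicCAS(word, assumed, __float_as_uint(sum));
    } while(seen != assumed);
    return __uint_as_float(assumed); // value observed before this thread's add
}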
// union U16BF16 {
// uint16_t u16;
// bhalf_t bf16;
// };
// inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b){
// U16BF16 xa {.bf16 = a};
// U16BF16 xb {.bf16 = b};
// U16BF16 xr;
// xr.u16 = xa.u16 + xb.u16;
// return xr.bf16;
// }
inline __host__ __device__ bhalf_t add_bf16_t(const bhalf_t& a, const bhalf_t& b)
{
return type_convert<bhalf_t>(type_convert<float>(a) + type_convert<float>(b));
}
inline __host__ __device__ bhalf2_t add_bf16x2_t(const bhalf2_t& a, const bhalf2_t& b)
{
bhalf2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]);
return rtn;
}
union U32BF162_ADDR
{
uint32_t* u32_a;
bhalf2_t* bf162_a;
};
union U32BF162
{
uint32_t u32;
bhalf2_t bf162;
};
template <>
__device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
{
U32BF162_ADDR dword_addr;
U32BF162 cur_v;
U32BF162 new_;
uint32_t old_v, new_v;
dword_addr.bf162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
return x;
}
// template <>
// __device__ bhalf2_t atomic_add<bhalf2_t>(bhalf2_t* p_dst, const bhalf2_t& x)
// {
// uint32_t * dword_addr = reinterpret_cast<uint32_t*>(p_dst);
// uint32_t cur_v = *dword_addr;
// uint32_t old_v, new_v;
// do {
// old_v = cur_v;
// bhalf2_t new_ = add_bf16x2_t(*reinterpret_cast<bhalf2_t*>(&cur_v), x);
// new_v = *reinterpret_cast<uint32_t*>(&new_);
// cur_v = atomicCAS(dword_addr, old_v, new_v);
// }while(cur_v != old_v);
// return x;
// }
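Note: the discarded union-based add kept in comments above would sum raw bf16 bit patterns, which is not a bfloat16 addition; the kept add_bf16_t instead widens to fp32, adds, and converts back. A standalone sketch of that widen-add-narrow step, using truncation for the narrowing (CK's type_convert may round differently):

#include <cstdint>
#include <cstdio>
#include <cstring>

using bf16_bits = std::uint16_t; // raw bf16 storage: the top 16 bits of an IEEE-754 float

inline float bf16_to_f32(bf16_bits x)
{
    const std::uint32_t u = std::uint32_t(x) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

inline bf16_bits f32_to_bf16(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return bf16_bits(u >> 16); // truncate the mantissa
}

inline bf16_bits add_bf16(bf16_bits a, bf16_bits b)
{
    return f32_to_bf16(bf16_to_f32(a) + bf16_to_f32(b));
}

int main()
{
    std::printf("%f\n", bf16_to_f32(add_bf16(f32_to_bf16(1.5f), f32_to_bf16(2.25f)))); // ~3.75
    return 0;
}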
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
...