Unverified commit 60f1f93a, authored by guangzlu, committed by GitHub

Merge branch 'attn-bwd-develop-qloop-dropout' into attn-bwd-develop-qloop

parents 1a306e0d 26115ce7
@@ -43,7 +43,7 @@ Kernel outputs:
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -72,13 +72,14 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;
+using DataType        = F16;
 using InputDataType   = F16;
 using OutputDataType  = F16;
 using GemmDataType    = F16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
-using ZDataType       = INT32; // INT32
+using ZDataType       = U16; // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -89,7 +90,7 @@ static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
 // When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
 // When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
-static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
+// static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
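The 4-vs-8 rule in the comment above is consistent with a 16-byte (128-bit) vectorized store: the scalar-per-vector count is 16 bytes divided by the element size. A minimal sketch of that reading (my interpretation of the comment, not CK code; `short` stands in for an F16 type):

```cpp
#include <cstddef>

// 16-byte global stores => elements per vector = 16 / sizeof(T):
// F32 -> 4, F16/BF16 -> 8, matching the hard-coded 4 and 8 in this diff.
template <typename OutputDataType>
constexpr std::size_t cshuffle_scalar_per_vector()
{
    return 16 / sizeof(OutputDataType);
}

static_assert(cshuffle_scalar_per_vector<float>() == 4, "F32 -> 4");
static_assert(cshuffle_scalar_per_vector<short>() == 8, "16-bit types -> 8");
```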
@@ -189,8 +190,7 @@ using DeviceGemmInstanceBWD =
         NumDimN,
         NumDimK,
         NumDimO,
-        InputDataType,
-        OutputDataType,
+        DataType,
         GemmDataType,
         ZDataType,
         LSEDataType,
@@ -248,9 +248,8 @@ using DeviceGemmInstanceBWD =
         1,              // CShuffleMXdlPerWavePerShuffle
         1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec,    // MaskingSpecialization
-        Deterministic>;
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -330,8 +329,7 @@ using DeviceGemmInstanceBWD =
         NumDimN,
         NumDimK,
         NumDimO,
-        InputDataType,
-        OutputDataType,
+        DataType,
         GemmDataType,
         ZDataType,
         LSEDataType,
@@ -352,7 +350,7 @@ using DeviceGemmInstanceBWD =
         1,
         256,
         128, // MPerBlock
-        128, // NPerBlock
+        64,  // NPerBlock
         64,  // KPerBlock
         64,  // Gemm1NPerBlock
         32,  // Gemm1KPerBlock
@@ -362,9 +360,9 @@ using DeviceGemmInstanceBWD =
         32, // MPerXDL
         32, // NPerXDL
         1,  // MXdlPerWave
-        4,  // NXdlPerWave
+        2,  // NXdlPerWave
         2,  // Gemm1NXdlPerWave
-        2,  // Gemm2NXdlPerWave
+        1,  // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -387,11 +385,10 @@ using DeviceGemmInstanceBWD =
         2,
         false,
         1,              // CShuffleMXdlPerWavePerShuffle
-        2,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-        MaskingSpec,    // MaskingSpecialization
-        Deterministic>;
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
 // using DeviceGemmInstanceBWD =
 //     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -723,8 +720,8 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;  // 512
-    ck::index_t N  = 512;  // 512
+    ck::index_t M  = 1000; // 512
+    ck::index_t N  = 1000; // 512
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
     ck::index_t G0 = 4; // 54
@@ -733,7 +730,7 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop = 0.2;
+    float p_drop = 0.1;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
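For orientation, a drop probability like `p_drop = 0.1` ultimately becomes two kernel-side quantities seen later in this commit: a keep threshold on 16-bit uniform samples and a rescale factor of 1 / (1 - p_drop). A hypothetical helper showing that mapping (CK's exact rounding may differ):

```cpp
#include <cstdint>

// Illustrative only: how p_drop maps onto the dropout kernel parameters.
struct DropoutParams
{
    std::uint16_t p_dropout_16bits; // keep-threshold compared against 16-bit uniforms
    float rescale;                  // 1 / (1 - p_drop), applied to surviving elements
};

inline DropoutParams make_dropout_params(float p_drop)
{
    const float keep_prob = 1.0f - p_drop;
    return {static_cast<std::uint16_t>(keep_prob * 65535.0f), 1.0f / keep_prob};
}
```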
@@ -1043,8 +1040,9 @@ int run(int argc, char* argv[])
         YElementOp{},
         p_drop,
         std::tuple<unsigned long long, unsigned long long>(seed, offset));
-    kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
-    vgrad_device_buf.SetZero();
+    qgrad_device_buf.SetZero(); // reset global accum buffer and rerun
+    // kgrad_device_buf.SetZero();
+    // vgrad_device_buf.SetZero();
     float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
     // 5 GEMM ops in total:
@@ -1152,8 +1150,9 @@ int run(int argc, char* argv[])
         std::ofstream fwd_file("./z_fwd_matrix_txt");
         fwd_file << z_fwd_gs_ms_ns << std::endl;
-        kgrad_device_buf.SetZero();
-        vgrad_device_buf.SetZero();
+        qgrad_device_buf.SetZero();
+        // kgrad_device_buf.SetZero();
+        // vgrad_device_buf.SetZero();
         auto argument_bwd = gemm_bwd.MakeArgument(
             static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
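Both hunks above switch the pre-run zeroing from the kgrad/vgrad buffers to qgrad. That reads as the qloop variant accumulating dQ across tiles (an inference from this diff, not stated in it), and an accumulation target must start from zero. The generic pattern, sketched on the host:

```cpp
#include <algorithm>
#include <cstddef>

// Generic pattern (not CK code): a kernel that accumulates partial results
// into a global buffer (e.g. via atomic adds) needs that buffer zero-filled
// before every launch, or the previous run's values leak into the next one.
void run_accumulating_kernel(float* accum_buf, std::size_t n)
{
    std::fill(accum_buf, accum_buf + n, 0.0f); // what SetZero() does for the device buffer
    // launch_kernel_that_accumulates_into(accum_buf, ...); // hypothetical launch
}
```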
...
@@ -50,6 +50,41 @@ struct BlockwiseDropout
         });
     }

+    template <typename CThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void
+    ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, index_t element_global_1d_id)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+        int philox_calls       = tmp_size / 4;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+        }
+
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
+
     template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void
     ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
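The new overload above keys the random stream to a caller-provided global element id instead of per-iteration template indices; each `get_random_4x16` call fills four 16-bit samples. The per-element rule it applies, restated on the host (illustrative only; `p_dropout_16bits` and `p_dropout_rescale` are BlockwiseDropout members in CK):

```cpp
#include <cstdint>

// Keep an element when its raw 16-bit sample is at or below the threshold,
// then either rescale survivors or mark dropped elements by sign.
inline float dropout_one(float val,
                         std::uint16_t sample,
                         std::uint16_t p_dropout_16bits,
                         float p_dropout_rescale,
                         bool using_sign_bit)
{
    const bool keep = sample <= p_dropout_16bits;
    if(using_sign_bit)
        return keep ? val : -val; // dropped elements keep magnitude, flip sign
    return keep ? val * p_dropout_rescale : 0.0f;
}
```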
@@ -86,6 +121,44 @@ struct BlockwiseDropout
         });
     }

+    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
+                                          ck::philox& ph,
+                                          index_t element_global_1d_id,
+                                          ZThreadBuffer& z_thread_buf)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+        int philox_calls       = tmp_size / 4;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+        }
+
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
+                z_thread_buf(offset) = tmp[tmp_index];
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
+
     template <typename CThreadBuffer,
               typename ZThreadBuffer,
               bool using_sign_bit,
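This second overload is identical except that it also records the raw 16-bit samples into `z_thread_buf`; the z tensor therefore stores samples rather than a 0/1 mask, so the mask can be re-derived against any threshold, for instance when cross-checking forward against backward:

```cpp
#include <cstdint>

// Sketch: reconstructing the keep/drop decision from a stored z value.
inline bool was_kept(std::uint16_t z_value, std::uint16_t p_dropout_16bits)
{
    return z_value <= p_dropout_16bits; // same comparison ApplyDropout used
}
```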
...
@@ -85,7 +85,9 @@ __global__ void
             const C0MatrixMask c0_matrix_mask,
             const float p_drop,
             const unsigned long long seed,
-            const unsigned long long offset)
+            const unsigned long long offset,
+            const index_t MRaw,
+            const index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -178,6 +180,7 @@ __global__ void
             ph,
             0);
     }
+
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -1005,6 +1008,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 arg.p_drop_,
                 arg.seed_,
                 arg.offset_);
+
         };
     // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -79,7 +79,9 @@ __global__ void
             const ushort p_dropout_in_16bits,
             const GemmAccDataType p_dropout_rescale,
             const unsigned long long seed,
-            const unsigned long long offset)
+            const unsigned long long offset,
+            const index_t MRaw,
+            const index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -112,8 +114,8 @@ __global__ void
             p_b_grid + b_batch_offset,
             p_b1_grid + b1_batch_offset,
             p_c_grid + c_batch_offset,
-            nullptr ? nullptr : p_z_grid + z_batch_offset,
-            nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+            p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+            p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
             p_shared,
             a_element_op,
             b_element_op,
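This hunk is a genuine bug fix: `nullptr ? a : b` has a constant-false condition, so the old code always evaluated `p_z_grid + z_batch_offset`, doing pointer arithmetic even when `p_z_grid` was null (undefined behavior, and the callee then received a bogus non-null pointer instead of nullptr). The fixed form tests the pointer actually being guarded:

```cpp
#include <cstddef>

// T* stands in for the real grid-pointer types in this kernel.
template <typename T>
T* guarded_batch_pointer(T* base, std::ptrdiff_t batch_offset)
{
    return base == nullptr ? nullptr : base + batch_offset; // offset only when non-null
}
```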
@@ -131,6 +133,9 @@ __global__ void
             p_dropout_in_16bits,
             p_dropout_rescale,
             ph,
+            g_idx,
+            MRaw,
+            NRaw,
             i);
     }
 }
@@ -141,8 +146,8 @@ __global__ void
             p_b_grid + b_batch_offset,
             p_b1_grid + b1_batch_offset,
             p_c_grid + c_batch_offset,
-            nullptr ? nullptr : p_z_grid + z_batch_offset,
-            nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+            p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+            p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
             p_shared,
             a_element_op,
             b_element_op,
@@ -160,6 +165,9 @@ __global__ void
             p_dropout_in_16bits,
             p_dropout_rescale,
             ph,
+            g_idx,
+            MRaw,
+            NRaw,
             0);
     }
 #else
@@ -644,6 +652,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         {
             is_lse_storing_ = false;
         }
+
+        // std::cout << "batch_count_: " << batch_count_ << std::endl;
     }

     void Print() const
@@ -803,7 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
                 arg.p_dropout_in_16bits_,
                 arg.p_dropout_rescale_,
                 arg.seed_,
-                arg.offset_);
+                arg.offset_,
+                arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
+                arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
         };
     // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -1266,6 +1266,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         const C0MatrixMask& c0_matrix_mask,
         const float p_drop,
         ck::philox& ph,
+        const index_t g_idx,
+        const index_t MRaw,
+        const index_t NRaw,
         const index_t block_idx_n)
     {
         const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
@@ -1528,7 +1531,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockID
                                                            m0, // MRepeat
-                                                           I1, // NRepeat
+                                                           n0, // NRepeat
                                                            m1, // MWaveId
                                                            n1, // NWaveId
                                                            m2, // MPerXdl
@@ -1562,9 +1565,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
             tensor_operation::element_wise::PassThrough,
             Sequence<I1, // MBlockId
-                     I1, // NBlockID
+                     I1, // NBlockId
                      m0, // MRepeat
-                     I1, // NRepeat
+                     n0, // NRepeat
                      m1, // MWaveId
                      n1, // NWaveId
                      m2, // MPerXdl
@@ -1577,8 +1580,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                  make_multi_index(block_work_idx_n, // MBlockId
-                                   0,                // NBlockId
+                  make_multi_index(0,                // MBlockId
+                                   block_work_idx_n, // NBlockId
                                    0,                // mrepeat
                                    0,                // nrepeat
                                    wave_id[I0],      // MWaveId
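The origin swap here pairs with the `MoveDstSliceWindow` change near the end of this file, which steps `(1, 0, ...)` instead of `(0, 1, ...)`: the z window now starts at this block's fixed N tile and advances along M between outer iterations, which is the traversal a q-loop backward pass implies (each workgroup owning one N tile and walking over M). A toy model of the corrected traversal, with made-up sizes:

```cpp
#include <cstdio>

int main()
{
    const int num_m_blocks     = 4;
    const int block_work_idx_n = 2; // this workgroup's fixed N tile

    int m_block = 0;                // origin after the fix: (MBlockId=0, NBlockId=n)
    for(int iter = 0; iter < num_m_blocks; ++iter)
    {
        std::printf("z tile written at (MBlockId=%d, NBlockId=%d)\n", m_block, block_work_idx_n);
        m_block += 1;               // MoveDstSliceWindow step (1, 0, ...): advance along M
    }
}
```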
@@ -1966,35 +1969,143 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             // save z to global
             if(p_z_grid)
             {
-                // P_dropped
-                static_for<0, n0, 1>{}([&](auto i) {
+                // 8d thread_desc in thread scope
+                constexpr auto c_thread_lengths =
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                // 8d block_desc in block scope
+                constexpr auto c_block_lengths =
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                constexpr auto M0 = c_block_lengths[I0];
+                constexpr auto N0 = c_block_lengths[I1];
+                constexpr auto M1 = c_block_lengths[I2];
+                constexpr auto N1 = c_block_lengths[I3];
+                constexpr auto M2 = c_block_lengths[I4];
+                constexpr auto N2 = c_block_lengths[I5];
+                constexpr auto N3 = c_block_lengths[I6];
+                constexpr auto N4 = c_block_lengths[I7];
+
+                // works like multi-dimension static_for (static_ford), but provides both the linear
+                // index as well as n-d index
+                using Acc0TileIterator = SpaceFillingCurve<
+                    decltype(c_thread_lengths),
+                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                    false>; // SnakeCurved
+
+                constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                auto m_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                auto n_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                auto m_global = m_local + m_block_data_idx_on_grid;
+                auto n_global = n_local + n_block_data_idx_on_grid;
+                auto global_elem_id =
+                    MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
+
+                // if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
+                //     printf("global_elem_id is %d \n", global_elem_id);
+                //}
+
+                // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
+                // if(get_thread_global_1d_id() == 0){
+                //     printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
+                //            Acc0TileIterator::GetNumOfAccess());
+                //     printf("n0.value is %d \n", n0.value);
+                //     printf("id_step is %d \n", id_step);
+                //}
+
                 blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
                                                         decltype(z_tenor_buffer),
-                                                        true,
-                                                        decltype(n0),
-                                                        decltype(i)>(
-                        s_slash_p_thread_buf, ph, z_tenor_buffer);
-                    z_thread_copy_vgpr_to_global.Run(
-                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                        z_tenor_buffer,
-                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        z_grid_buf);
-                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
-                });
-                z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                    make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
+                                                        true>(
+                    s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer);
+
+                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                                 z_tenor_buffer,
+                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                                 z_grid_buf);
+
+                // P_dropped
+                // static_for<0, n0, 1>{}([&](auto i) {
+                //     blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
+                //                                             decltype(z_tenor_buffer),
+                //                                             true,
+                //                                             decltype(n0),
+                //                                             decltype(i)>(s_slash_p_thread_buf,
+                //                                                          ph,
+                //                                                          global_elem_id + id_step * i.value,
+                //                                                          z_tenor_buffer);
+                //
+                //     z_thread_copy_vgpr_to_global.Run(
+                //         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                //         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                //         z_tenor_buffer,
+                //         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                //         z_grid_buf);
+                //     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                //         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                //         make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
+                //});
+                // z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                //     make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
             }
             else
             {
+                // 8d thread_desc in thread scope
+                constexpr auto c_thread_lengths =
+                    s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                // 8d block_desc in block scope
+                constexpr auto c_block_lengths =
+                    s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                constexpr auto M0 = c_block_lengths[I0];
+                constexpr auto N0 = c_block_lengths[I1];
+                constexpr auto M1 = c_block_lengths[I2];
+                constexpr auto N1 = c_block_lengths[I3];
+                constexpr auto M2 = c_block_lengths[I4];
+                constexpr auto N2 = c_block_lengths[I5];
+                constexpr auto N3 = c_block_lengths[I6];
+                constexpr auto N4 = c_block_lengths[I7];
+
+                // works like multi-dimension static_for (static_ford), but provides both the linear
+                // index as well as n-d index
+                using Acc0TileIterator = SpaceFillingCurve<
+                    decltype(c_thread_lengths),
+                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                    false>; // SnakeCurved
+
+                constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                auto m_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                auto n_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                auto m_global = m_local + m_block_data_idx_on_grid;
+                auto n_global = n_local + n_block_data_idx_on_grid;
+                auto global_elem_id =
+                    MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
+
+                // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
+
                 ignore = z_grid_buf;

                 // P_dropped
                 blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
-                    s_slash_p_thread_buf, ph);
+                    s_slash_p_thread_buf, ph, global_elem_id);
             }

             block_sync_lds(); // wait for gemm1 LDS read
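The key new quantity in this hunk is `global_elem_id`: a row-major 1-d id over the G x MRaw x NRaw score tensor that does not depend on block/thread tiling. Feeding it to philox makes every dropout decision a pure function of (seed, element), so the forward and backward passes agree on the mask. A worked example with arbitrary numbers:

```cpp
#include <cstdint>
#include <cstdio>

// global_elem_id = MRaw * NRaw * g_idx + m_global * NRaw + n_global
std::int64_t global_elem_id(std::int64_t g_idx,
                            std::int64_t MRaw,
                            std::int64_t NRaw,
                            std::int64_t m_global,
                            std::int64_t n_global)
{
    return MRaw * NRaw * g_idx + m_global * NRaw + n_global;
}

int main()
{
    // batch g_idx = 2 of a 1000 x 1000 score matrix, element (m=3, n=7):
    // 2 * 1000000 + 3 * 1000 + 7 = 2003007, unique regardless of tiling
    std::printf("%lld\n", static_cast<long long>(global_elem_id(2, 1000, 1000, 3, 7)));
}
```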
@@ -2178,7 +2289,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm1::c_block_slice_copy_step); // step M

             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
             lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                 lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(1, 0, 0, 0));
             y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
...
@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const ushort p_dropout_in_16bits,
         FloatGemmAcc p_dropout_rescale,
         ck::philox& ph,
+        const index_t g_idx,
+        const index_t MRaw,
+        const index_t NRaw,
         const index_t block_idx_m)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -857,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockID
                                                            m0, // MRepeat
-                                                           I1, // NRepeat
+                                                           n0, // NRepeat
                                                            m1, // MWaveId
                                                            n1, // NWaveId
                                                            m2, // MPerXdl
@@ -888,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             Sequence<I1, // MBlockId
                      I1, // NBlockID
                      m0, // MRepeat
-                     I1, // NRepeat
+                     n0, // NRepeat
                      m1, // MWaveId
                      n1, // NWaveId
                      m2, // MPerXdl
@@ -986,6 +989,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                     block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                 auto m_global = m_local + m_block_data_idx_on_grid;
                 auto n_global = n_local + n_block_data_idx_on_grid;
+
+                // if(get_thread_global_1d_id()==0){
+                //     printf("m_global is %d \n", m_global);
+                //     printf("n_global is %d \n", n_global);
+                //}
+
                 if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
                 {
                     acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
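The masking context line above is worth a remark: masked score elements are set to negative infinity so that the subsequent row-wise softmax assigns them exactly zero weight, since exp(-inf) = 0:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    std::printf("%f\n", std::exp(-INFINITY)); // prints 0.000000: masked logits vanish in softmax
}
```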
@@ -1012,17 +1021,61 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             if constexpr(IsDropout) // dropout
             {
+                // 8d thread_desc in thread scope
+                constexpr auto c_thread_lengths =
+                    blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                // 8d block_desc in block scope
+                constexpr auto c_block_lengths =
+                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                constexpr auto M0 = c_block_lengths[I0];
+                constexpr auto N0 = c_block_lengths[I1];
+                constexpr auto M1 = c_block_lengths[I2];
+                constexpr auto N1 = c_block_lengths[I3];
+                constexpr auto M2 = c_block_lengths[I4];
+                constexpr auto N2 = c_block_lengths[I5];
+                constexpr auto N3 = c_block_lengths[I6];
+                constexpr auto N4 = c_block_lengths[I7];
+
+                // works like multi-dimension static_for (static_ford), but provides both the linear
+                // index as well as n-d index
+                using Acc0TileIterator = SpaceFillingCurve<
+                    decltype(c_thread_lengths),
+                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                    false>; // SnakeCurved
+
+                constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+                auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
+                auto m_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                auto n_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                auto m_global = m_local + m_block_data_idx_on_grid;
+                auto n_global = n_local + n_block_data_idx_on_grid;
+                auto global_elem_id =
+                    MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
+
+                // if(get_thread_global_1d_id()==1){
+                //     printf("at 1 m_global is %d \n", m_global);
+                //     printf("at 1 n_global is %d \n", n_global);
+                //     printf("at 1 global_elem_id is %d \n", global_elem_id);
+                // }
+
+                // index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
+
                 // save z to global
                 if(p_z_grid)
                 {
-                    static_for<0, n0, 1>{}([&](auto i) {
                     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
                                                             decltype(z_tenor_buffer),
-                                                            false,
-                                                            decltype(n0),
-                                                            decltype(i)>(
-                            acc_thread_buf, ph, z_tenor_buffer);
+                                                            false>(
+                        acc_thread_buf, ph, global_elem_id, z_tenor_buffer);

                     z_thread_copy_vgpr_to_global.Run(
                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1030,13 +1083,28 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                         z_tenor_buffer,
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         z_grid_buf);
-                        z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                            make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
-                    });
-                    z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
+                    // static_for<0, n0, 1>{}([&](auto i) {
+                    //     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
+                    //                                             decltype(z_tenor_buffer),
+                    //                                             false,
+                    //                                             decltype(n0),
+                    //                                             decltype(i)>(
+                    //         acc_thread_buf, ph, global_elem_id + id_step * i.value,
+                    //         z_tenor_buffer);
+                    //
+                    //     z_thread_copy_vgpr_to_global.Run(
+                    //         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    //         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                    //         z_tenor_buffer,
+                    //         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    //         z_grid_buf);
+                    //     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                    //         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    //         make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
+                    //});
+                    // z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                    //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                    //     make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));

                     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
@@ -1046,7 +1114,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                     // ignore = z_grid_buf;
                     // P_dropped
                     blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
-                        acc_thread_buf, ph);
+                        acc_thread_buf, ph, global_elem_id);
                 }
             }
...
@@ -84,6 +84,17 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }

+    __device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        out[0] = static_cast<ushort>(tmp_ph.x);
+        out[1] = static_cast<ushort>(tmp_ph.y);
+        out[2] = static_cast<ushort>(tmp_ph.z);
+        out[3] = static_cast<ushort>(tmp_ph.w);
+    }
+
     private:
     struct ull2
     {
...
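`get_random_4x16` derives four 16-bit samples from one philox 4x32 invocation by truncating each 32-bit word to its low 16 bits. The property the whole dropout scheme rests on is that philox is counter-based: its output depends only on (seed, counter), with no sequential state, so any kernel can regenerate any element's sample on demand. A stand-in demonstration using splitmix64 in place of philox (illustration only, not CK's generator):

```cpp
#include <cstdint>
#include <cstdio>

// splitmix64: a small stateless mixer standing in for philox 4x32.
std::uint64_t splitmix64(std::uint64_t x)
{
    x += 0x9e3779b97f4a7c15ull;
    x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull;
    x = (x ^ (x >> 27)) * 0x94d049bb133111ebull;
    return x ^ (x >> 31);
}

bool keep(std::uint64_t seed, std::uint64_t elem_id, std::uint16_t threshold)
{
    // Same (seed, elem_id) always yields the same sample, so forward and
    // backward kernels can regenerate identical dropout masks without ever
    // materializing the mask in memory.
    return static_cast<std::uint16_t>(splitmix64(seed ^ (elem_id * 0x9e3779b97f4a7c15ull))) <= threshold;
}

int main()
{
    const auto threshold = static_cast<std::uint16_t>(0.9 * 65535); // keep-prob ~0.9
    std::printf("%d %d\n", keep(1, 42, threshold), keep(1, 42, threshold)); // identical results
}
```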