"...composable_kernel.git" did not exist on "ac9e01e2cc3721be24619807adc444e1f59a9d25"
Unverified Commit 064f596c authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #878 from ROCmSoftwarePlatform/mha-dropout-zcheck

Add examples to verify the batched dropout kernel and device-operator
parents bcbeed99 e568bfdb
...@@ -2,4 +2,7 @@ add_example_executable(example_batched_multihead_attention_bias_forward_v2 batch ...@@ -2,4 +2,7 @@ add_example_executable(example_batched_multihead_attention_bias_forward_v2 batch
add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp) add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp)
add_example_executable(example_batched_multihead_attention_bias_backward_v2 batched_multihead_attention_bias_backward_v2.cpp) add_example_executable(example_batched_multihead_attention_bias_backward_v2 batched_multihead_attention_bias_backward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_bias_backward_v2 grouped_multihead_attention_bias_backward_v2.cpp) add_example_executable(example_grouped_multihead_attention_bias_backward_v2 grouped_multihead_attention_bias_backward_v2.cpp)
\ No newline at end of file
add_example_executable(example_batched_multihead_attention_bias_forward_v2_zcheck batched_multihead_attention_bias_forward_v2_zcheck.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/host_common_util.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using INT32 = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DataType = F16;
using GemmDataType = F16;
using ADataType = DataType;
using B0DataType = DataType;
using B1DataType = DataType;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = DataType;
using ZDataType = U16; // INT32
using LSEDataType = F32;
using Acc0BiasDataType = F16;
using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
using DeviceDropoutInstance =
ck::tensor_operation::device::DeviceBatchedDropout<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
GemmDataType,
ZDataType,
GemmDataType,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4>; // NXdlPerWave
#include "run_batched_multihead_attention_bias_forward_zcheck.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
using ck::host_common::dumpBufferToFile;
int init_method = 1;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 200; // 120
ck::index_t N = 200; // 1000
ck::index_t K = DIM;
ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 8;
ck::index_t G1 = 4;
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.1;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 2)
{
init_method = std::stoi(argv[1]);
}
else if(argc == 11)
{
init_method = std::stoi(argv[1]);
M = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
O = std::stoi(argv[5]);
G0 = std::stoi(argv[6]);
G1 = std::stoi(argv[7]);
p_drop = std::stof(argv[8]);
input_permute = std::stoi(argv[9]);
output_permute = std::stoi(argv[10]);
}
else
{
printf("arg1: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg2 to 7: M, N, K, O, G0, G1\n");
printf("arg8: drop_prob \n");
printf("arg9 to 10: input / output permute\n");
exit(0);
}
float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<ZDataType> z_gs_ms_ns_2(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms_host_result.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) *
lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf_2(sizeof(ZDataType) * z_gs_ms_ns_2.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm_op = DeviceGemmInstance{};
auto gemm_invoker = gemm_op.MakeInvoker();
// run for storing z tensor
auto gemm_arg =
gemm_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()),
nullptr,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths,
d_gs_ms_ns_strides,
{},
{},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread
if(!gemm_op.IsSupportedArgument(gemm_arg))
{
std::cout << gemm_op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
c_device_buf.SetZero();
lse_device_buf.SetZero();
gemm_invoker.Run(gemm_arg, StreamConfig{nullptr, false});
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
dumpBufferToFile("forward_z.dat", z_gs_ms_ns.mData.data(), z_gs_ms_ns.mData.size());
// do Dropout
auto dropout_op = DeviceDropoutInstance();
auto dropout_invoker = dropout_op.MakeInvoker();
auto dropout_arg =
dropout_op.MakeArgument(static_cast<ZDataType*>(z_device_buf_2.GetDeviceBuffer()),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
{seed, offset});
dropout_invoker.Run(dropout_arg, StreamConfig{nullptr, false});
z_device_buf_2.FromDevice(z_gs_ms_ns_2.mData.data());
dumpBufferToFile("canonic_z.dat", z_gs_ms_ns_2.mData.data(), z_gs_ms_ns_2.mData.size());
return ck::utils::check_integer_err(z_gs_ms_ns.mData, z_gs_ms_ns_2.mData, 1.0e-5);
}
...@@ -25,7 +25,8 @@ namespace ck { ...@@ -25,7 +25,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ZDataType, template <typename GridwiseDropout_,
typename ZDataType,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -60,13 +61,13 @@ __global__ void ...@@ -60,13 +61,13 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
GridwiseDropout::Run(z_matrix_ptr, GridwiseDropout_::Run(z_matrix_ptr,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
block_2_ctile_map, block_2_ctile_map,
ph, ph,
z_random_matrix_offset, z_random_matrix_offset,
raw_n_padded); raw_n_padded);
#else #else
ignore = p_z_grid; ignore = p_z_grid;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
...@@ -109,7 +110,7 @@ template <index_t NumDimG, ...@@ -109,7 +110,7 @@ template <index_t NumDimG,
index_t NPerXDL, index_t NPerXDL,
index_t MXdlPerWave, index_t MXdlPerWave,
index_t NXdlPerWave> index_t NXdlPerWave>
struct DeviceBatchedDropout : public BaseOperator struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
...@@ -220,7 +221,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -220,7 +221,7 @@ struct DeviceBatchedDropout : public BaseOperator
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]}, b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)} batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{ {
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(z_grid_desc_g_m_n_); compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(z_grid_desc_g_m_n_);
...@@ -228,7 +229,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -228,7 +229,7 @@ struct DeviceBatchedDropout : public BaseOperator
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n_); z_grid_desc_m_n_);
// Print(); // Print();
...@@ -237,14 +238,6 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -237,14 +238,6 @@ struct DeviceBatchedDropout : public BaseOperator
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]); n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
} }
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
}
// pointers // pointers
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
...@@ -257,7 +250,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -257,7 +250,7 @@ struct DeviceBatchedDropout : public BaseOperator
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -294,6 +287,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -294,6 +287,7 @@ struct DeviceBatchedDropout : public BaseOperator
auto launch_kernel = [&]() { auto launch_kernel = [&]() {
const auto kernel = kernel_batched_dropout< const auto kernel = kernel_batched_dropout<
GridwiseDropout,
ZDataType, ZDataType,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
...@@ -307,7 +301,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -307,7 +301,7 @@ struct DeviceBatchedDropout : public BaseOperator
0, 0,
arg.p_z_grid_, arg.p_z_grid_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
...@@ -336,9 +330,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -336,9 +330,7 @@ struct DeviceBatchedDropout : public BaseOperator
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if DEBUG_LOG (void)arg;
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
...@@ -425,8 +417,7 @@ struct DeviceBatchedDropout : public BaseOperator ...@@ -425,8 +417,7 @@ struct DeviceBatchedDropout : public BaseOperator
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", " << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", " << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", " << "CSpec" << getTensorSpecializationString(CSpec);
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -238,6 +238,14 @@ struct GridwiseBatchedDropout ...@@ -238,6 +238,14 @@ struct GridwiseBatchedDropout
constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto m4 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
// only used for BlockwiseDropout
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2 * m3 * m4, n0 * n1 * n2));
// only used for providing ApplyDropoutAttnBwdSaveZ
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
static_cast<unsigned short>(0.8f * 65535.f), static_cast<FloatGemmAcc>(1.0f / 0.8f)};
// //
// z vgpr copy to global // z vgpr copy to global
// //
...@@ -332,6 +340,9 @@ struct GridwiseBatchedDropout ...@@ -332,6 +340,9 @@ struct GridwiseBatchedDropout
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
// gemm0 M loop // gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
...@@ -372,5 +383,6 @@ struct GridwiseBatchedDropout ...@@ -372,5 +383,6 @@ struct GridwiseBatchedDropout
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop } while(0 < gemm0_m_block_outer_index--); // end j loop
}; };
};
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment