Commit bcbeed99 authored by danyao12

dropout deviceop

parent 7d9a6cc2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseDropout,
typename ZDataType,
typename AGridDesc_AK0_M_AK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_dropout(ZDataType* __restrict__ p_z_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const unsigned long long seed,
const unsigned long long offset,
const index_t raw_m_padded,
const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
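// Worked example (illustrative numbers): with batch_count = 2 and a grid of 8 workgroups,
// num_blocks_per_batch = 4, so blocks 0-3 serve batch 0 and blocks 4-7 serve batch 1.
// Each batch writes its Z values starting z_batch_offset elements into p_z_grid, and draws
// its random numbers starting at element g_idx * raw_m_padded * raw_n_padded of the Philox stream.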
GridwiseDropout::Run(z_matrix_ptr,
a_grid_desc_ak0_m_ak1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
block_2_ctile_map,
ph,
z_random_matrix_offset,
raw_n_padded);
#else
ignore = p_z_grid;
ignore = a_grid_desc_ak0_m_ak1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = seed;
ignore = offset;
ignore = raw_m_padded;
ignore = raw_n_padded;
#endif // end of architecture guard (gfx908/gfx90a/gfx940/gfx941/gfx942)
}
// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename GemmDataType,
typename ZDataType,
typename GemmAccDataType,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave>
struct DeviceBatchedDropout : public BaseOperator
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
using DeviceOp = DeviceBatchedDropout;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
/*
Descriptors for inputs:
Q, K, V, Y, dY, per-row softmax stats
Descriptors for outputs:
dQ, dK, dV
*/
// Q in Gemm A position
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
Number<AK1>{});
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch() {}
ComputeBasePtrOfStridedBatch(const ZGridDesc_G_M_N& z_grid_desc_g_m_n)
: z_grid_desc_g_m_n_(z_grid_desc_g_m_n)
{
}
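// GetZBasePtr returns the element offset of batch g_idx within the Z tensor,
// i.e. g_idx multiplied by the G-dimension stride of z_grid_desc_g_m_n_.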
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
};
using GridwiseDropout = GridwiseBatchedDropout<ZDataType,
GemmDataType,
GemmAccDataType,
AGridDesc_AK0_M_AK1,
KGridDesc_N_K,
ZGridDesc_M_N,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave>;
// Argument
struct Argument : public BaseArgument
{
Argument(ZDataType* p_z_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_z_grid_{p_z_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
k_grid_desc_n_k_{
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]},
batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
{
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(z_grid_desc_g_m_n_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n_);
// Print();
m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
}
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
}
// pointers
ZDataType* p_z_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
ZGridDesc_M_N z_grid_desc_m_n_;
KGridDesc_N_K k_grid_desc_n_k_;
// batch offsets
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map
typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
unsigned long long seed_;
unsigned long long offset_;
index_t m_raw_padded_;
index_t n_raw_padded_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_) * arg.batch_count_;
float ave_time = 0;
auto launch_kernel = [&]() {
const auto kernel = kernel_batched_dropout<
GridwiseDropout,
ZDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseDropout::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_z_grid_,
arg.a_grid_desc_ak0_m_ak1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.seed_,
arg.offset_,
arg.m_raw_padded_,
arg.n_raw_padded_);
};
ave_time = launch_kernel();
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
#if DEBUG_LOG
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
return GridwiseDropout::CheckValidity();
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(ZDataType* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_z,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_z,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
seeds);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedDropout"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
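For illustration, here is a minimal host-side sketch of how the new DeviceBatchedDropout op can be instantiated and launched. The tile sizes, data types, problem shape, and the p_z_dev buffer name below are placeholder choices for the example, not values fixed by this commit; the sketch assumes this header and <vector> are already included.

// Sketch only: instantiate the op for packed [G, M, K] x [G, N, K] inputs and launch it
// to fill the Z (dropout mask) buffer. All numeric choices are illustrative.
using DeviceDropoutOp = ck::tensor_operation::device::DeviceBatchedDropout<
    1, 1, 1, 1, 1,     // NumDimG, NumDimM, NumDimN, NumDimK, NumDimO
    ck::half_t,        // GemmDataType
    unsigned short,    // ZDataType (random-number mask)
    float,             // GemmAccDataType
    ck::tensor_operation::device::GemmSpecialization::MNKOPadding,
    ck::tensor_operation::device::TensorSpecialization::Default, // ASpec
    ck::tensor_operation::device::TensorSpecialization::Default, // BSpec
    ck::tensor_operation::device::TensorSpecialization::Default, // B1Spec
    ck::tensor_operation::device::TensorSpecialization::Default, // CSpec
    256,               // BlockSize
    128, 128, 32, 128, // MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock
    8, 8, 2,           // AK1, BK1, B1K1
    32, 32,            // MPerXDL, NPerXDL
    4, 1>;             // MXdlPerWave, NXdlPerWave

void run_dropout_example(unsigned short* p_z_dev) // device buffer holding G * M * N Z elements
{
    const ck::index_t G = 4, M = 512, N = 512, K = 64;

    // Packed row-major lengths/strides for A[G, M, K], B[G, N, K] and Z[G, M, N].
    std::vector<ck::index_t> a_lengths{G, M, K}, a_strides{M * K, K, 1};
    std::vector<ck::index_t> b_lengths{G, N, K}, b_strides{N * K, K, 1};
    std::vector<ck::index_t> z_lengths{G, M, N}, z_strides{M * N, N, 1};

    auto argument = DeviceDropoutOp::MakeArgument(p_z_dev,
                                                  a_lengths, a_strides,
                                                  b_lengths, b_strides,
                                                  z_lengths, z_strides,
                                                  std::make_tuple(0ULL, 0ULL)); // Philox {seed, offset}
    auto invoker  = DeviceDropoutOp::MakeInvoker();
    invoker.Run(argument, StreamConfig{});
}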
@@ -26,6 +26,7 @@ template <typename ZDataType,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
+          index_t Gemm1NPerBlock,
           index_t AK1Value,
           index_t BK1Value,
           index_t MPerXdl,
@@ -113,21 +114,18 @@ struct GridwiseBatchedDropout
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
-            make_tuple(Number<MPerBlock + I1>{} * AK1, AK1, I1));
+        return make_naive_tensor_descriptor(make_tuple(AK0, Number<MPerBlock>{}, AK1),
+                                            make_tuple(Number<MPerBlock + I1>{} * AK1, AK1, I1));
     }
 
     __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
-            make_tuple(Number<NPerBlock + I1>{} * BK1, BK1, I1));
+        return make_naive_tensor_descriptor(make_tuple(BK0, Number<NPerBlock>{}, BK1),
+                                            make_tuple(Number<NPerBlock + I1>{} * BK1, BK1, I1));
     }
 
-    __host__ __device__ static constexpr bool
-    CheckValidity()
+    __host__ __device__ static constexpr bool CheckValidity()
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -141,7 +139,7 @@ struct GridwiseBatchedDropout
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
     {
-        return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, KPerBlock, KGridDesc_N_K>(
+        return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, Gemm1NPerBlock, KGridDesc_N_K>(
             k_grid_desc_n_k);
     }
 
@@ -151,7 +149,6 @@ struct GridwiseBatchedDropout
     using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
         MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
 
-
     // S Gemm
     struct Gemm0
     {
@@ -205,15 +202,14 @@ struct GridwiseBatchedDropout
     };
 
     template <typename Block2CTileMap>
-    __device__ static void
-    Run(ZDataType* __restrict__ p_z_grid,
-        const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
-        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
-            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-        const Block2CTileMap& block_2_ctile_map,
-        ck::philox& ph,
-        const index_t z_random_matrix_offset,
-        const index_t raw_n_padded)
+    __device__ static void Run(ZDataType* __restrict__ p_z_grid,
+                               const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
+                               const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
+                                   z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                               const Block2CTileMap& block_2_ctile_map,
+                               ck::philox& ph,
+                               const index_t z_random_matrix_offset,
+                               const index_t raw_n_padded)
     {
         // divide block work by [N, K]
         const auto block_work_idx =
@@ -294,7 +290,7 @@ struct GridwiseBatchedDropout
                 1,    // DstScalarStrideInVector
                 true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                       make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
                                        block_work_idx[I0],               // NBlockId
                                        0,                                // MRepeat
                                        0,                                // NRepeat
                                        wave_id[I0],                      // MWaveId
@@ -308,7 +304,7 @@ struct GridwiseBatchedDropout
         // 8d thread_desc in thread scope
         constexpr auto c_thread_lengths =
             s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
 
         // 8d block_desc in block scope
         constexpr auto c_block_lengths =
             s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
@@ -324,15 +320,15 @@ struct GridwiseBatchedDropout
         // works like multi-dimension static_for (static_ford), but provides both the linear
         // index as well as n-d index
-        using Acc0TileIterator = SpaceFillingCurve<
-            decltype(c_thread_lengths),
-            typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
-            typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
-            false>; // SnakeCurved
+        using Acc0TileIterator =
+            SpaceFillingCurve<decltype(c_thread_lengths),
+                              typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                              typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                              false>; // SnakeCurved
 
         constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                        make_unmerge_transform(make_tuple(N0, N1, N2))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
@@ -346,36 +342,29 @@ struct GridwiseBatchedDropout
                 // save z to global
                 auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-                auto m_local =
-                    block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-                auto n_local =
-                    block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                 auto m_global = m_local + m_block_data_idx_on_grid;
                 auto n_global = n_local + n_block_data_idx_on_grid;
 
                 auto global_tile_id = z_random_matrix_offset +
                                       (m_global / DropoutTile) * DropoutTile * raw_n_padded +
                                       (n_global / DropoutTile) * DropoutTile;
-                auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
-                                      (n_global % DropoutTile) * raw_n_padded;
+                auto global_elem_id =
+                    global_tile_id + (wave_m_n_id[I0] * M4) + (n_global % DropoutTile) * raw_n_padded;
 
-                blockwise_dropout
-                    .template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
-                                                       decltype(z_tensor_buffer),
-                                                       decltype(DropoutTile),
-                                                       true>(s_slash_p_thread_buf,
-                                                             ph,
-                                                             global_elem_id,
-                                                             z_tensor_buffer,
-                                                             raw_n_padded);
-
-                z_thread_copy_vgpr_to_global.Run(
-                    z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                    z_tensor_buffer,
-                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-                    z_grid_buf);
+                blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
+                                                                    decltype(z_tensor_buffer),
+                                                                    decltype(DropoutTile),
+                                                                    true>(
+                    s_slash_p_thread_buf, ph, global_elem_id, z_tensor_buffer, raw_n_padded);
+
+                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                                                 z_tensor_buffer,
+                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+                                                 z_grid_buf);
 
                 // move slice window
                 z_thread_copy_vgpr_to_global.MoveDstSliceWindow(