Commit 8551dd43 authored by Anthony Chang's avatar Anthony Chang
Browse files

start with dY

start with dY
parent ecd5f7c9
......@@ -40,7 +40,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using DeviceGemmInstance = DeviceGemmInstance0;
using DeviceGemmInstance = DeviceGemmInstance1;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
......@@ -44,9 +44,14 @@ using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
......@@ -54,7 +59,6 @@ static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
#if 0
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......@@ -63,7 +67,70 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#endif
// Device operation instance used by this example: a
// DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle specialization built from the
// NumDim* constants, data types, element-wise ops and Gemm/Tensor/Masking
// specializations declared above in this file.
// NOTE(review): the template's parameter list is not visible here; the trailing
// comments on individual arguments are the only in-file labels, and unlabeled
// arguments should be checked against the device op's declaration.
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
// Five element-wise ops; the order matches the element-op arguments handed to
// MakeArgument in run() (q, k, s, v, y), with Scale in the "s" position.
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
// NOTE(review): the next two values are unlabeled in this file — presumably a
// prefetch-stage count and the block size; confirm against the template declaration.
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
......@@ -306,6 +373,7 @@ int run(int argc, char* argv[])
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
// TODO ANT: make sure K/V gradients are zeroed
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
......@@ -313,14 +381,17 @@ int run(int argc, char* argv[])
ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data());
// TODO ANT: attention backward kernel
#if 0
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
......@@ -335,11 +406,11 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
q_element_op,
k_element_op,
s_element_op,
v_element_op,
y_element_op);
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
if(!gemm.IsSupportedArgument(argument))
{
......@@ -361,7 +432,6 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
#endif
bool pass = true;
if(do_verification)
......
......@@ -185,6 +185,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!");
}
// Transposed XDL output layout, supporting C_xdl' = B_xdl' * A_xdl'.
// Returns a packed per-thread C descriptor whose lengths are
// (MRepeat, NRepeat, 1, 1, N, M0, M1, M2), where M0, M1, M2 and N are the four
// entries of xdlops_gemm.GetCM0M1M2NThreadBlkLengths() in that order. Compared
// with the non-transposed variant below, the N length precedes the M lengths.
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
    constexpr auto tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    // Pull the N length first, then the three M lengths, mirroring the order in
    // which they are placed into the descriptor.
    constexpr auto len_n  = tblk_lens[I3];
    constexpr auto len_m0 = tblk_lens[I0];
    constexpr auto len_m1 = tblk_lens[I1];
    constexpr auto len_m2 = tblk_lens[I2];

    return make_naive_tensor_descriptor_packed(make_tuple(
        Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, len_n, len_m0, len_m1, len_m2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
......
......@@ -209,7 +209,8 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
// TODO ANT: is this necessary?
// block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
......
......@@ -54,7 +54,8 @@ template <typename SrcData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename SliceLengths, // TODO ANT: can we generalize to allow sub-wg slice transfer? need
// to distinguish what dimensions are spread across waves
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
......
......@@ -19,4 +19,37 @@ struct ThisThreadBlock
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
// Thread-group descriptor that tags the calling thread with a pair of wave
// coordinates (mwave, nwave) so callers can test whether the thread falls in a
// rectangular sub-range of waves. ThreadPerBlock is the thread count reported
// by GetNumOfThread().
template <index_t ThreadPerBlock>
struct SubThreadBlock
{
    static constexpr index_t kNumThread_ = ThreadPerBlock;

    // Stores the wave coordinates supplied by the caller; no validation is done.
    __device__ SubThreadBlock(int mwave, int nwave) : mwave_(mwave), nwave_(nwave) {}

    // Number of threads in this group (the ThreadPerBlock template argument).
    __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }

    // Returns true iff the stored coordinates lie inside both half-open ranges:
    //   mwave_range[I0] <= mwave_ < mwave_range[I1]  and
    //   nwave_range[I0] <= nwave_ < nwave_range[I1].
    // Each range only needs to be indexable with I0 / I1. The two range types
    // are deduced independently so that ranges built from different
    // compile-time constants (and therefore of different tuple types) can be
    // passed together; a single shared template parameter would make such a
    // call fail template argument deduction.
    // The function does not modify the object, so it is const-qualified.
    template <typename MWaveRange, typename NWaveRange>
    __device__ constexpr bool IsBelong(const MWaveRange& mwave_range,
                                       const NWaveRange& nwave_range) const
    {
        // range[I0] inclusive, range[I1] exclusive
        return mwave_ >= mwave_range[I0] && mwave_ < mwave_range[I1] &&
               nwave_ >= nwave_range[I0] && nwave_ < nwave_range[I1];
    }

    // Thread id within the whole workgroup, as returned by get_thread_local_1d_id().
    __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }

    private:
    index_t mwave_, nwave_;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
};
} // namespace ck
add_test_executable(test_space_filling_curve space_filling_curve.cpp)
add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment