Unverified Commit 2837e6b3 authored by Chao Liu, committed by GitHub

Batch gemm softmax gemm (#11)

* make it simple

* batched gemm+softmax+gemm
parent 6bc9ee05
@@ -4,3 +4,4 @@ add_example_executable(example_gemm_gemm gemm_gemm.cpp)
add_example_executable(example_reduce reduce.cpp)
add_example_executable(example_softmax softmax.cpp)
add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "batched_gemm_softmax_gemm.hpp"
int main(int argc, char* argv[])
{
using QDataType = ck::half_t;
using KDataType = ck::half_t;
using VDataType = ck::half_t;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck::half_t;
using OaccDataType = float;
using ODataType = ck::half_t;
ck::index_t Batch = 16;
ck::index_t M0 = 4096;
ck::index_t N0 = 4096;
ck::index_t K0 = 128;
ck::index_t N1 = 128;
if(argc == 6)
{
Batch = std::stoi(argv[1]);
M0 = std::stoi(argv[2]);
N0 = std::stoi(argv[3]);
K0 = std::stoi(argv[4]);
N1 = std::stoi(argv[5]);
}
std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};
std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};
std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};
std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};
std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};
std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};
// host verify
Tensor<QDataType> q_host(q_lengths, q_strides);
Tensor<KDataType> k_host(k_lengths, k_strides);
Tensor<VDataType> v_host(v_lengths, v_strides);
Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
Tensor<PDataType> p_host_ref(p_lengths, p_strides);
Tensor<ODataType> o_host_ref(o_lengths, o_strides);
Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
#endif
// reference
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
p_host_ref);
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host, o_host_ref);
DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());
q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
constexpr ck::index_t kM0PerBlock = 128;
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
float ave_time =
launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
BatchedGemmSoftmaxGemm<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kGridSize,
kBlockSize,
0,
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
Batch,
K0, // StrideQ
K0, // StrideK
N0, // StrideV
N1, // StrideO
M0 * K0, // BatchStrideQ
N0 * K0, // BatchStrideK
N1 * N0, // BatchStrideV
M0 * N1); // BatchStrideO
o_buf.FromDevice(o_host_dev.mData.data());
std::size_t flop =
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
std::size_t num_btype =
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !ck::utils::check_err(o_host_dev, o_host_ref);
}
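As a side note, the launch math above is easy to sanity-check on the host. The sketch below is illustrative only and not part of the commit; it reproduces the arithmetic for the default problem size, assuming a wavefront size of 64 (what warpSize evaluates to on the gfx9 GPUs these examples target). The executable itself takes the sizes as Batch M0 N0 K0 N1 on the command line, e.g. example_batched_gemm_softmax_gemm 16 4096 4096 128 128, using the target name added to the CMakeLists above.

// Standalone sketch of the grid/occupancy arithmetic used in main() above (assumes warpSize == 64).
#include <cassert>
#include <iostream>

int main()
{
    constexpr int kBlockSize    = 256;
    constexpr int kM0PerBlock   = 128;
    constexpr int kN1PerBlock   = 128;
    constexpr int kWarpPerCu    = 8;                          // 2 warps per SIMD x 4 SIMDs
    constexpr int kWarpSize     = 64;                         // assumed AMD wavefront size
    constexpr int kWarpPerBlock = kBlockSize / kWarpSize;     // 4
    constexpr int kBlockPerCu   = kWarpPerCu / kWarpPerBlock; // 2

    const int Batch = 16, M0 = 4096, N1 = 128;                // default sizes from the example
    const int grid  = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);

    assert(kBlockPerCu == 2);
    assert(grid == 16 * 32 * 1);
    std::cout << "grid size " << grid << ", blocks per CU " << kBlockPerCu << "\n";
    return 0;
}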
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
#include "gemm_softmax_gemm_impl.hpp"
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
ck::index_t kBlockSize,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
struct BatchedGemmSoftmaxGemm
{
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const ck::index_t M0,
const ck::index_t N0,
const ck::index_t K0,
const ck::index_t N1,
const ck::index_t /* Batch */,
const ck::index_t StrideQ,
const ck::index_t StrideK,
const ck::index_t StrideV,
const ck::index_t StrideO,
const ck::index_t BatchStrideQ,
const ck::index_t BatchStrideK,
const ck::index_t BatchStrideV,
const ck::index_t BatchStrideO) const
{
using namespace ck;
// divide problem
const index_t num_tile_m0 = M0 / kM0PerBlock;
const index_t num_tile_n1 = N1 / kN1PerBlock;
const index_t id_block = get_block_id();
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck::make_tuple(quotient, modulus);
};
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
const auto kernel_impl = GemmSoftmaxGemmImpl<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
v_ptr + iBatch * BatchStrideV,
o_ptr + iBatch * BatchStrideO,
M0,
N0,
K0,
N1,
StrideQ,
StrideK,
StrideV,
StrideO,
iM0,
iN1);
}
};
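The divide/modulo pair above lays the blocks out batch-major, then along M0 tiles, then along N1 tiles. A small host-side model of that decomposition (a hypothetical helper written for illustration, not part of the kernel) makes the ordering explicit:

// Sketch: host-side model of the block-id decomposition in BatchedGemmSoftmaxGemm::operator().
#include <cassert>
#include <tuple>

std::tuple<int, int, int> decompose_block_id(int id_block, int num_tile_m0, int num_tile_n1)
{
    const int id_tile_n = id_block % num_tile_n1;
    const int itmp      = id_block / num_tile_n1;
    const int id_tile_m = itmp % num_tile_m0;
    const int id_batch  = itmp / num_tile_m0;
    return std::make_tuple(id_batch, id_tile_m, id_tile_n);
}

int main()
{
    const int num_tile_m0 = 4096 / 128, num_tile_n1 = 128 / 128;
    // block ids are laid out batch-major, then m-tile, then n1-tile
    for(int b = 0, id = 0; b < 16; ++b)
        for(int m = 0; m < num_tile_m0; ++m)
            for(int n = 0; n < num_tile_n1; ++n, ++id)
                assert(decompose_block_id(id, num_tile_m0, num_tile_n1) == std::make_tuple(b, m, n));
    return 0;
}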
@@ -81,7 +81,7 @@ int main(int argc, char* argv[])
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
// reference gemm
reference_gemm<ADataType, ADataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
DeviceMem a_buf(sizeof(ADataType) * a_host.GetElementSpaceSize());
DeviceMem b_buf(sizeof(BDataType) * b_host.GetElementSpaceSize());
@@ -99,6 +99,10 @@ int main(int argc, char* argv[])
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
const auto gemm_kernel = Gemm<ADataType,
BDataType,
AccDataType,
@@ -114,23 +118,24 @@ int main(int argc, char* argv[])
kGemmNPerBlock,
kGemmKPerBlock>{};
float ave_time =
launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
gemm_kernel,
kGridSize,
kBlockSize,
0,
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
M,
N,
K,
K,
K,
N,
AElementFunction{},
BElementFunction{},
CElementFunction{});
c_buf.FromDevice(c_host_dev.mData.data());
@@ -20,9 +20,9 @@ int main(int argc, char* argv[])
{
using A0DataType = ck::half_t;
using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using Acc0DataType = float;
using C0DataType = ck::half_t;
using Acc1DataType = float;
using C1DataType = ck::half_t;
@@ -67,8 +67,9 @@ int main(int argc, char* argv[])
ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);
// reference gemm
reference_gemm<A0DataType, B0DataType, Acc0DataType, C0DataType>(a0_host, b0_host, c0_host_ref);
reference_gemm<C0DataType, B1DataType, Acc1DataType, C1DataType>(
c0_host_ref, b1_host, c1_host_ref);
DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
@@ -89,35 +90,39 @@ int main(int argc, char* argv[])
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
float ave_time =
launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
GemmGemm<A0DataType,
B0DataType,
B1DataType,
Acc0DataType,
C0DataType,
Acc1DataType,
C1DataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kGridSize,
kBlockSize,
0,
static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
K0, // Lda0
K0, // Ldb0
N0, // Ldb1
N1); // Ldc1
c1_buf.FromDevice(c1_host_dev.mData.data());
@@ -16,12 +16,13 @@
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
// C0 = A0 * B0
// C1 = C0 * B1
template <typename A0DataType,
typename B0DataType,
typename B1DataType,
typename Acc0DataType,
typename C0DataType,
typename Acc1DataType,
typename C1DataType,
ck::index_t kBlockSize,
@@ -19,14 +19,14 @@ int main(int argc, char* argv[])
int main(int argc, char* argv[])
{
using QDataType = ck::half_t;
using KDataType = ck::half_t;
using VDataType = ck::half_t;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck::half_t;
using OaccDataType = float;
using ODataType = ck::half_t;
ck::index_t M0 = 13312;
ck::index_t N0 = 4096;
@@ -41,56 +41,57 @@ int main(int argc, char* argv[])
N1 = std::stoi(argv[4]);
}
std::array<ck::index_t, 2> q_lengths{M0, K0};
std::array<ck::index_t, 2> q_strides{K0, 1};
std::array<ck::index_t, 2> k_lengths{N0, K0};
std::array<ck::index_t, 2> k_strides{K0, 1};
std::array<ck::index_t, 2> v_lengths{N1, N0};
std::array<ck::index_t, 2> v_strides{N0, 1};
std::array<ck::index_t, 2> s_lengths{M0, N0};
std::array<ck::index_t, 2> s_strides{N0, 1};
std::array<ck::index_t, 2> p_lengths{M0, N0};
std::array<ck::index_t, 2> p_strides{N0, 1};
std::array<ck::index_t, 2> o_lengths{M0, N1};
std::array<ck::index_t, 2> o_strides{N1, 1};
// host verify
Tensor<QDataType> q_host(q_lengths, q_strides);
Tensor<KDataType> k_host(k_lengths, k_strides);
Tensor<VDataType> v_host(v_lengths, v_strides);
Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
Tensor<PDataType> p_host_ref(p_lengths, p_strides);
Tensor<ODataType> o_host_ref(o_lengths, o_strides);
Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
#endif
// reference
reference_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
reference_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref, p_host_ref);
reference_gemm<PDataType, VDataType, OaccDataType, ODataType>(p_host_ref, v_host, o_host_ref);
DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());
q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
constexpr ck::index_t kM0PerBlock = 128;
constexpr ck::index_t kN0PerBlock = 128;
@@ -102,41 +103,46 @@ int main(int argc, char* argv[])
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
float ave_time =
launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
GemmSoftmaxGemm<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kGridSize,
kBlockSize,
0,
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
K0, // StrideQ
K0, // StrideK
N0, // StrideV
N1); // StrideO
o_buf.FromDevice(o_host_dev.mData.data());
std::size_t flop = std::size_t(2) * M0 * N0 * K0 + std::size_t(2) * M0 * N1 * N0;
std::size_t num_btype = sizeof(QDataType) * M0 * K0 + sizeof(KDataType) * N0 * K0 +
sizeof(VDataType) * N1 * N0 + sizeof(ODataType) * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -145,8 +151,5 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !ck::utils::check_err(o_host_dev, o_host_ref);
}
@@ -17,15 +17,19 @@
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
#include "gemm_softmax_gemm_impl.hpp"
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
ck::index_t kBlockSize,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
@@ -33,138 +37,21 @@ template <typename A0DataType,
ck::index_t kN1PerBlock>
struct GemmSoftmaxGemm
{
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const ck::index_t M0,
const ck::index_t N0,
const ck::index_t K0,
const ck::index_t N1,
const ck::index_t StrideQ,
const ck::index_t StrideK,
const ck::index_t StrideV,
const ck::index_t StrideO) const
{
using namespace ck;
// divide problem
const auto num_tile_n1 = N1 / kN1PerBlock;
@@ -176,215 +63,33 @@ struct GemmSoftmaxGemm
const auto iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
const auto iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
const auto kernel_impl = GemmSoftmaxGemmImpl<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{};
kernel_impl(q_ptr,
k_ptr,
v_ptr,
o_ptr,
M0,
N0,
K0,
N1,
StrideQ,
StrideK,
StrideV,
StrideO,
iM0,
iN1);
}
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
ck::index_t kBlockSize,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
struct GemmSoftmaxGemmImpl
{
// block gemm0 pipeline
using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
ck::tile_program::block::BlockGemmPipelineProblem<
QDataType,
KDataType,
SaccDataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>,
ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;
// block gemm1
using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
PDataType,
VDataType,
OaccDataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0
// 2d
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_desc;
}
#else
// fake XOR
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_desc_d1_d2_d3,
make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
make_pass_through_transform(2)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto b_lds_desc_n_k = transform_tensor_descriptor(
b_lds_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
make_pass_through_transform(kKPerBlock)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_desc_n_k;
}
#endif
__device__ static constexpr auto MakeVDramTileDistribution()
{
using namespace ck;
using namespace ck::tile_program;
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
__device__ static constexpr ck::index_t GetStaticLdsSize()
{
using namespace ck;
return math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
static_cast<index_t>(MakeVLdsBlockDescriptor().GetElementSpaceSize() *
sizeof(VDataType)));
}
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const ck::index_t M0,
const ck::index_t N0,
const ck::index_t K0,
const ck::index_t N1,
const ck::index_t StrideQ,
const ck::index_t StrideK,
const ck::index_t StrideV,
const ck::index_t StrideO,
const ck::index_t iM0,
const ck::index_t iN1) const
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// allocate LDS
__shared__ char smem_ptr[GetStaticLdsSize()];
// Q/K/V DRAM and DRAM window
// FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1]
const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{});
const auto k_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), Number<32>{}, Number<1>{});
const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{});
auto q_dram_window = make_tile_window(
q_dram, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
// V LDS and LDS window
// V LDS occupies the same LDS allocation as the Q/K LDS
auto v_lds = make_tensor_view<AddressSpaceEnum::Lds>(reinterpret_cast<VDataType*>(smem_ptr),
MakeVLdsBlockDescriptor());
auto v_lds_window = make_tile_window(
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
constexpr auto gemm1 = BlockGemm1{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SaccBlockTileType =
decltype(gemm0_pipeline(q_dram_window, k_dram_window, 0, nullptr));
using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window));
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
tile_elementwise_inout([](auto& e) { e = NumericLimits<SMPLComputeDataType>::Lowest(); },
m);
tile_elementwise_inout([](auto& e) { e = 0; }, l);
// loop over Column of S (J loop)
index_t iN0 = 0;
do
{
// Sacc{j} = Q * K{j}
const auto s_acc =
gemm0_pipeline(q_dram_window, k_dram_window, K0 / kK0PerBlock, smem_ptr);
// S{j}
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
// m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest());
block_tile_reduce_sync(m_local, f_max);
// m{j-1}
const auto m_old = m;
// m{j}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
// Pcompute{j}
auto p_compute =
make_static_distributed_tensor<SMPLComputeDataType>(s.GetTileDistribution());
constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans();
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]);
});
});
// rowsum(Pcompute{j})
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = math::exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Block GEMM1: Oacc{j} += P{j} * V{j}
{
// load V{j}
const auto v = load_tile(v_dram_window);
// wait for gemm0 pipeline to finish
block_sync_lds();
store_tile(v_lds_window, v);
// wait for store_tile to finish
block_sync_lds();
// Oacc{j} += P{j} * V{j}
gemm1(o_acc, p, v_lds_window);
// wait for gemm1 to finish
block_sync_lds();
}
// move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0});
move_tile_window(v_dram_window, {0, kN0PerBlock});
iN0 += kN0PerBlock;
} while(iN0 < N0);
// O
constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = 1 / l[i_idx];
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
// type cast Oacc into O
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);
// O DRAM and O DRAM window
auto o_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), Number<32>{}, Number<1>{});
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(Number<kM0PerBlock>{}, Number<kN1PerBlock>{}),
{iM0, iN1},
o.GetTileDistribution());
// store O
store_tile(o_dram_window, o);
}
};
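For reference, the j-loop above is the usual online-softmax (FlashAttention-style) recurrence; the summary below is my reading of the code, not text from the commit. With K_j and V_j the j-th kN0PerBlock slice of K[N0, K0] and V[N1, N0] (stored K-contiguous, so the contractions below use their transposes), and initial values m_0 = -inf, l_0 = 0, Oacc_0 = 0, each iteration computes, per row of the M0 tile:

\begin{aligned}
S_j &= Q\,K_j^{\mathsf T}, \\
m_j &= \max\bigl(m_{j-1},\ \mathrm{rowmax}(S_j)\bigr), \\
\tilde P_j &= \exp\bigl(S_j - m_j\bigr), \\
l_j &= e^{\,m_{j-1}-m_j}\,l_{j-1} + \mathrm{rowsum}(\tilde P_j), \\
O^{\mathrm{acc}}_j &= e^{\,m_{j-1}-m_j}\,O^{\mathrm{acc}}_{j-1} + \tilde P_j\,V_j^{\mathsf T}, \\
O &= \mathrm{diag}(l_J)^{-1}\,O^{\mathrm{acc}}_J \quad \text{after the last step } j = J.
\end{aligned}

The rescaling of Oacc by exp(m_{j-1} - m_j) is the `o_acc(i_j_idx) *= tmp` line, and the single division by l after the loop is the `o_acc(i_j_idx) *= tmp` / `1 / l[i_idx]` epilogue above.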
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
const Tensor<BDataType>& b_b_n_k,
Tensor<CDataType>& c_b_m_n)
{
const int N = b_b_n_k.mDesc.GetLengths()[1];
const int K = b_b_n_k.mDesc.GetLengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_b_m_k(batch, m, k);
BDataType v_b = b_b_n_k(batch, n, k);
v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
c_b_m_n(batch, m, n) = ck::type_convert<CDataType>(v_acc);
}
};
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
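Note the layout convention: both operands are K-contiguous, i.e. A is [batch, M, K] and B is [batch, N, K], so the contraction is effectively A times B-transposed within each batch. A tiny standalone analogue of the inner loop (plain C++ on flat arrays, not using the CK Tensor class) just to make that convention concrete:

// Illustrative only: same contraction convention as reference_batched_gemm, on plain arrays.
#include <cassert>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 3;
    const std::vector<float> a = {1, 2, 3,  // A[M][K], row-major
                                  4, 5, 6};
    const std::vector<float> b = {1, 0, 1,  // B[N][K], row-major (K-contiguous, like the reference)
                                  0, 1, 0};
    std::vector<float> c(M * N, 0.f);       // C[M][N] = sum_k A[m][k] * B[n][k]

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[n * K + k];

    assert(c[0] == 4.f && c[1] == 2.f && c[2] == 10.f && c[3] == 5.f);
    return 0;
}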
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_batched_softmax(const Tensor<ADataType>& a_b_m_n, Tensor<BDataType>& b_b_m_n)
{
const int N = a_b_m_n.mDesc.GetLengths()[2];
auto f = [&](auto batch, auto m) {
AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_max = v_max < v_a ? v_a : v_max;
}
AccDataType v_exp_sum = 0;
// sum
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_exp_sum += ck::math::exp(v_a - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
b_b_m_n(batch, m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
}
};
make_ParallelTensorFunctor(f, b_b_m_n.mDesc.GetLengths()[0], b_b_m_n.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
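The reference subtracts the row maximum before exponentiating; this leaves the softmax value unchanged mathematically but keeps exp from overflowing for large inputs. A minimal standalone illustration of the same trick (plain C++, not CK code):

// Illustrative only: max-subtracted softmax stays finite even when naive exp(x) would overflow.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

std::vector<float> softmax_stable(const std::vector<float>& x)
{
    const float x_max = *std::max_element(x.begin(), x.end());
    float sum = 0.f;
    for(float v : x)
        sum += std::exp(v - x_max);
    std::vector<float> y;
    for(float v : x)
        y.push_back(std::exp(v - x_max) / sum);
    return y;
}

int main()
{
    // exp(1000.f) overflows to inf, but the max-subtracted form is well defined
    const auto y = softmax_stable({1000.f, 999.f, 998.f});
    float sum = 0.f;
    for(float v : y)
    {
        assert(std::isfinite(v));
        sum += v;
    }
    assert(std::fabs(sum - 1.f) < 1e-6f);
    return 0;
}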
@@ -6,28 +6,30 @@
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_gemm(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_n_k,
Tensor<CDataType>& c_m_n)
{
const int N = b_n_k.mDesc.GetLengths()[0];
const int K = b_n_k.mDesc.GetLengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_n_k(n, k);
v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck::type_convert<CDataType>(v_acc);
}
};
make_ParallelTensorFunctor(f, c_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency());
}
@@ -143,23 +143,20 @@ struct Softmax
sweep_tile_span(a_spans[I0], [&](auto idx0) {
constexpr auto m_idx = make_tuple(idx0);
const auto v_max = max_block_tensor[m_idx];
AccDataType v_exp_sum = exp_sum_block_tensor[m_idx];
sweep_tile_span(a_spans[I1], [&](auto idx1) {
constexpr auto m_n_idx = make_tuple(idx0, idx1);
const auto v_a = a_block_tensor[m_n_idx];
// exp and sum
v_exp_sum += math::exp(v_a - v_max);
});
exp_sum_block_tensor(m_idx) = v_exp_sum;
});
move_tile_window(a_block_window, {0, kNPerBlock});
@@ -196,21 +193,20 @@ struct Softmax
sweep_tile_span(a_spans[I0], [&](auto idx0) {
constexpr auto m_idx = make_tuple(idx0);
const auto v_max = max_block_tensor[m_idx];
const auto v_exp_sum = exp_sum_block_tensor[m_idx];
sweep_tile_span(a_spans[I1], [&](auto idx1) {
constexpr auto m_n_idx = make_tuple(idx0, idx1);
const auto v_a = a_block_tensor[m_n_idx];
// exp
const BDataType v_b =
type_convert<BDataType>(math::exp(v_a - v_max) / v_exp_sum);
b_block_tensor(m_n_idx) = v_b;
});
});
@@ -160,11 +160,11 @@ float launch_kernel(const StreamConfig& stream_config,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
{
const auto kernel = kernel_wrapper<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
return launch_and_time_kernel(
stream_config, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
}
@@ -161,18 +161,18 @@ __device__ void block_tile_reduce(AccDistributedTensor_& acc_tensor,
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
auto acc = acc_tensor[acc_dstr_idx];
// FIXME
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto in = in_tensor[in_dstr_idx];
acc = reduce_func(acc, in);
});
acc_tensor(acc_dstr_idx) = acc;
});
#endif
}
@@ -105,6 +105,31 @@ struct StaticDistributedTensor
});
}
template <typename TileDistributedIndices>
__host__ __device__ constexpr const DataType& operator[](TileDistributedIndices) const
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx =
GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{});
return thread_buf_[Number<ThreadTensorDesc{}.CalculateOffset(y_idx)>{}];
}
template <typename TileDistributedIndices>
__host__ __device__ constexpr DataType& operator()(TileDistributedIndices)
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx =
GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{});
return thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(y_idx)>{});
}
#if 0
template <index_t... Ys>
__host__ __device__ auto GetElementFromYsIndex(Sequence<Ys...> idx_ys) const
{
@@ -116,7 +141,6 @@ struct StaticDistributedTensor
{
thread_buf_(Number<ThreadTensorDesc{}.CalculateOffset(idx_ys)>{}) = v;
}
template <typename TileDistributedIndices>
__host__ __device__ auto GetElementFromTileDistributedIndices(TileDistributedIndices) const
{
@@ -139,6 +163,7 @@ struct StaticDistributedTensor
return SetElementFromYsIndex(y_idx, v);
}
#endif
//
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, kThreadElementSpaceSize, true> thread_buf_;
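The new operator[] / operator() pair is read/write accessor sugar over the old GetElementFromTileDistributedIndices / SetElementFromTileDistributedIndices calls, which is how the softmax and block-reduce hunks above are simplified. A stripped-down standalone analogue of the pattern (illustrative only, not the actual CK class):

// Illustrative only: const operator[] for reads, non-const operator() for writes,
// mirroring the accessor style adopted by StaticDistributedTensor in this commit.
#include <array>
#include <cassert>

template <typename DataType, int kSize>
struct TinyThreadTensor
{
    const DataType& operator[](int i) const { return buf_[i]; } // read access
    DataType& operator()(int i) { return buf_[i]; }             // write access

    std::array<DataType, kSize> buf_{};
};

int main()
{
    TinyThreadTensor<float, 4> t;
    t(2) = 3.f;          // write through operator()
    t(2) += 1.f;
    assert(t[2] == 4.f); // read through operator[]
    return 0;
}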