Unverified commit f3baea0d, authored by Chao Liu, committed by GitHub

Gemm+softmax+gemm (#9)

* adding gemm+softmax+gemm
parent cfdce3eb
...@@ -4,3 +4,4 @@ add_example_executable(example_gemm gemm.cpp)
add_example_executable(example_gemm_gemm gemm_gemm.cpp)
add_example_executable(example_reduce reduce.cpp)
add_example_executable(example_softmax softmax.cpp)
add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
...@@ -215,7 +215,6 @@ struct GemmGemm
// init Acc1
tile_elementwise_inout([](auto& acc1) { acc1 = 0; }, acc1_block_tile);
#if 0
index_t iN0 = 0; index_t iN0 = 0;
do do
...@@ -255,47 +254,6 @@ struct GemmGemm
iN0 += kN0PerBlock;
} while(iN0 < N0);
#else
index_t iN0 = 0;
do
{
// load b1
const auto b1_block_tile = load_tile(b1_dram_block_window);
// Block GEMM0 pipeline: acc0 = a0 * b0
const auto acc0_block_tile = block_gemm0_pipeline(
a0_dram_block_window, b0_dram_block_window, K0 / kK0PerBlock, p_smem_char);
// type cast acc0 into c0
const auto c0_block_tile =
tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, acc0_block_tile);
// Block GEMM1: acc1 += c0 * b1
{
// wait for block gemm0 pipeline to finish
ps.block_sync_lds();
store_tile(b1_lds_block_window, b1_block_tile);
// wait for store_tile to finish
ps.block_sync_lds();
// acc1 += c0 * b1
block_gemm1(acc1_block_tile, c0_block_tile, b1_lds_block_window);
// wait for block gemm1 to finish
ps.block_sync_lds();
}
// move tile windows
move_tile_window(b0_dram_block_window, {kN0PerBlock, 0});
move_tile_window(b1_dram_block_window, {0, kN0PerBlock});
iN0 += kN0PerBlock;
} while(iN0 < N0);
#endif
// type cast acc1 into c1
const auto c1_block_tile =
...
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "reference_gemm.hpp"
#include "reference_softmax.hpp"
#include "gemm_softmax_gemm.hpp"
int main(int argc, char* argv[])
{
using A0DataType = ck::half_t;
using B0DataType = ck::half_t;
using Acc0DataType = float;
using C0DataType = ck::half_t;
using D0DataType = ck::half_t;
using B1DataType = ck::half_t;
using Acc1DataType = float;
using C1DataType = ck::half_t;
ck::index_t M0 = 13312;
ck::index_t N0 = 4096;
ck::index_t K0 = 128;
ck::index_t N1 = 128;
if(argc == 5)
{
M0 = std::stoi(argv[1]);
N0 = std::stoi(argv[2]);
K0 = std::stoi(argv[3]);
N1 = std::stoi(argv[4]);
}
std::array<ck::index_t, 2> a0_lengths{M0, K0};
std::array<ck::index_t, 2> a0_strides{K0, 1};
std::array<ck::index_t, 2> b0_lengths{N0, K0};
std::array<ck::index_t, 2> b0_strides{K0, 1};
std::array<ck::index_t, 2> c0_lengths{M0, N0};
std::array<ck::index_t, 2> c0_strides{N0, 1};
std::array<ck::index_t, 2> d0_lengths{M0, N0};
std::array<ck::index_t, 2> d0_strides{N0, 1};
std::array<ck::index_t, 2> b1_lengths{N1, N0};
std::array<ck::index_t, 2> b1_strides{N0, 1};
std::array<ck::index_t, 2> c1_lengths{M0, N1};
std::array<ck::index_t, 2> c1_strides{N1, 1};
// host verify
Tensor<A0DataType> a0_host(a0_lengths, a0_strides);
Tensor<B0DataType> b0_host(b0_lengths, b0_strides);
Tensor<C0DataType> c0_host_ref(c0_lengths, c0_strides);
Tensor<D0DataType> d0_host_ref(d0_lengths, d0_strides);
Tensor<B1DataType> b1_host(b1_lengths, b1_strides);
Tensor<C1DataType> c1_host_ref(c1_lengths, c1_strides);
Tensor<C1DataType> c1_host_dev(c1_lengths, c1_strides);
#if 1
ck::utils::FillUniformDistributionIntegerValue<A0DataType>{-3.f, 3.f}(a0_host);
ck::utils::FillUniformDistributionIntegerValue<B0DataType>{-3.f, 3.f}(b0_host);
ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);
#elif 0
ck::utils::FillUniformDistribution<A0DataType>{-3.f, 3.f}(a0_host);
ck::utils::FillUniformDistribution<B0DataType>{-3.f, 3.f}(b0_host);
ck::utils::FillUniformDistribution<B1DataType>{-3.f, 3.f}(b1_host);
#else
ck::utils::FillConstant<A0DataType>{1.0f}(a0_host);
ck::utils::FillConstant<B0DataType>{1.0f}(b0_host);
ck::utils::FillConstant<B1DataType>{1.0f}(b1_host);
#endif
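// Note: the integer-valued fill above is most likely the default so that the fp16 device
// result can match the host reference exactly (no rounding noise in the comparison); the
// other two branches are kept for experimenting with real-valued or constant inputs.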
// reference
reference_gemm<A0DataType, B0DataType, C0DataType, float>(a0_host, b0_host, c0_host_ref);
reference_softmax<C0DataType, float, D0DataType>(c0_host_ref, d0_host_ref);
reference_gemm<D0DataType, B1DataType, C1DataType, float>(d0_host_ref, b1_host, c1_host_ref);
DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
DeviceMem b1_buf(sizeof(B1DataType) * b1_host.GetElementSpaceSize());
DeviceMem c1_buf(sizeof(C1DataType) * c1_host_ref.GetElementSpaceSize());
a0_buf.ToDevice(a0_host.mData.data());
b0_buf.ToDevice(b0_host.mData.data());
b1_buf.ToDevice(b1_host.mData.data());
constexpr ck::index_t kM0PerBlock = 128;
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
std::cout << "grid size " << kGridSize << std::endl;
float ave_time = launch(ProgramServer{},
GemmSoftmaxGemm<A0DataType,
B0DataType,
Acc0DataType,
C0DataType,
B1DataType,
Acc1DataType,
C1DataType,
kBlockSize,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kGridSize,
kBlockSize,
static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
K0, // Lda0
K0, // Ldb0
N0, // Ldb1
N1); // Ldc1
c1_buf.FromDevice(c1_host_dev.mData.data());
std::size_t flop = std::size_t(2) * M0 * N0 * K0 + std::size_t(2) * M0 * N1 * N0;
std::size_t num_btype = sizeof(A0DataType) * M0 * K0 + sizeof(B0DataType) * N0 * K0 +
sizeof(B1DataType) * N1 * N0 + sizeof(C1DataType) * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
// LogRangeAsType<float>(std::cout << "C1 dev: ", c1_host_dev.mData, ", ", 16, 20) << std::endl;
// LogRangeAsType<float>(std::cout << "C1 ref: ", c1_host_ref.mData, ", ", 16, 20) << std::endl;
return !ck::utils::check_err(c1_host_dev, c1_host_ref);
}
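// Usage sketch (assuming the example_gemm_softmax_gemm target added to CMakeLists.txt above):
//   ./example_gemm_softmax_gemm                    # default 13312 x 4096 x 128 x 128 problem
//   ./example_gemm_softmax_gemm 1024 1024 64 128   # override M0 N0 K0 N1 on the command line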
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "tile_program.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
// C0 = A0 * B0
// D0 = softmax(C0)
// C1 = D0 * B1
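//
// The main loop in operator() below uses a FlashAttention-style "online softmax": for each
// kN0PerBlock-wide tile j it keeps a running row-max m and row-sum l and rescales the partial
// output, roughly (a sketch of what the code does, not a formal derivation):
//
//   m_j = max(m_{j-1}, rowmax(S_j))
//   P_j = exp(S_j - m_j)
//   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
//   O_j = exp(m_{j-1} - m_j) * O_{j-1} + P_j * B1_j
//
// and only divides O by l once, after the last tile.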
template <typename A0DataType,
typename B0DataType,
typename Acc0DataType,
typename C0DataType,
typename B1DataType,
typename Acc1DataType,
typename C1DataType,
ck::index_t kBlockSize,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
struct GemmSoftmaxGemm
{
// block gemm0 pipeline
using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
ck::tile_program::block::BlockGemmPipelineProblem<
A0DataType,
B0DataType,
Acc0DataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>,
ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;
// block gemm1
using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
C0DataType,
B1DataType,
Acc1DataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
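// Note: the "K" extent of block gemm1 is gemm0's N0 tile (kN0PerBlock), since
// C1 = softmax(C0) * B1 contracts over the N0 dimension.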
#if 0
// 2d
__host__ __device__ static constexpr auto MakeB1LdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_block_desc;
}
#else
// fake XOR
__host__ __device__ static constexpr auto MakeB1LdsBlockDescriptor()
{
using namespace ck;
using BDataType = B1DataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
make_pass_through_transform(2)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
make_pass_through_transform(kKPerBlock)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
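// The "fake XOR" descriptor above swizzles 16-byte (kK1-element) chunks of the K dimension
// with part of the N index, presumably to reduce LDS bank conflicts when B1 is read back by
// block gemm1; the #if 0 branch keeps the plain row-major layout for reference.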
__host__ __device__ static constexpr auto MakeB1DramTileDistribution()
{
using namespace ck;
using namespace ck::tile_program;
using BDataType = B1DataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
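// For this example's configuration (half-precision B1, 16-byte vector loads, kBlockSize = 256,
// kN1PerBlock = kN0PerBlock = 128, and - assuming a 64-lane wavefront - get_warp_size() == 64):
//   K1 = 16 / 2 = 8,  K0 = 128 / 8 = 16,  N2 = 64 / 16 = 4,  N1 = 256 / 64 = 4,  N0 = 128 / 16 = 8,
// i.e. each thread covers N0 * K1 = 64 half elements of the 128 x 128 B1 tile as eight
// 16-byte chunks on different rows.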
__host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
{
using namespace ck;
return math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
static_cast<index_t>(MakeB1LdsBlockDescriptor().GetElementSpaceSize() *
sizeof(B1DataType)));
}
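// Rough LDS budget for this example: the B1 tile alone needs about
// kN1PerBlock * kN0PerBlock * sizeof(half) = 128 * 128 * 2 bytes = 32 KiB, and it reuses the
// same allocation as the gemm0 pipeline, so the block needs the larger of the two footprints.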
__host__ __device__ void operator()(ProgramServer& ps,
const A0DataType* p_a0,
const B0DataType* p_b0,
const B1DataType* p_b1,
C1DataType* p_c1,
ck::index_t M0,
ck::index_t N0,
ck::index_t K0,
ck::index_t N1,
ck::index_t Lda0,
ck::index_t Ldb0,
ck::index_t Ldb1,
ck::index_t Ldc1)
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// FIXME: assume layout A0[M0, K0], B0[N0, K0], B1[N1, N0], C1[M0, N1]
const auto a0_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_a0, make_tuple(M0, K0), make_tuple(Lda0, 1), Number<32>{}, Number<1>{});
const auto b0_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_b0, make_tuple(N0, K0), make_tuple(Ldb0, 1), Number<32>{}, Number<1>{});
const auto b1_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_b1, make_tuple(N1, N0), make_tuple(Ldb1, 1), Number<32>{}, Number<1>{});
// divide problem
const auto id_block = ps.get_block_id();
const auto num_tile_m0 = M0 / kM0PerBlock;
const auto num_tile_n1 = N1 / kN1PerBlock;
const auto block2tile = ps(make_cluster_descriptor(make_tuple(num_tile_m0, num_tile_n1)));
const auto id_tile = block2tile.CalculateBottomIndex(make_tuple(id_block));
const auto iM0 = ps.read_first_lane(id_tile.At<0>() * kM0PerBlock);
const auto iN1 = ps.read_first_lane(id_tile.At<1>() * kN1PerBlock);
__shared__ char p_smem_char[GetStaticLdsSize()];
// A0 DRAM block window
auto a0_dram_block_window = make_tile_window(
a0_dram_grid, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});
// B0 DRAM block window
auto b0_dram_block_window = make_tile_window(
b0_dram_grid, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});
// Block GEMM0 pipeline
constexpr auto block_gemm0_pipeline = BlockGemm0Pipeline{};
// B1 DRAM window
auto b1_dram_block_window =
make_tile_window(b1_dram_grid,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
{iN1, 0},
MakeB1DramTileDistribution());
// B1 LDS tensor view: occupies the same LDS allocation as block_gemm0_pipeline
auto b1_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(
reinterpret_cast<B1DataType*>(p_smem_char), MakeB1LdsBlockDescriptor());
auto b1_lds_block_window = make_tile_window(
b1_lds_block, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
// Block GEMM1
constexpr auto block_gemm1 = BlockGemm1{};
// Acc0 tile
using Acc0BlockTileType =
decltype(block_gemm0_pipeline(a0_dram_block_window, b0_dram_block_window, 0, nullptr));
// Acc1 tile
auto acc1_block_tile = decltype(block_gemm1(
tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, Acc0BlockTileType{}),
b1_dram_block_window)){};
const auto f_max = [](auto v0, auto v1) { return max(v0, v1); };
const auto f_sum = [](auto v0, auto v1) { return v0 + v1; };
// init Acc1
tile_elementwise_inout([](auto& acc1) { acc1 = 0; }, acc1_block_tile);
// m, l tile
auto m = decltype(block_tile_reduce<Acc0DataType>(
Acc0BlockTileType{}, Sequence<1>{}, f_max, Acc0DataType{0})){};
// init m, l
auto l = make_static_distributed_tensor<Acc0DataType>(m.GetTileDistribution());
tile_elementwise_inout([](auto& m_v) { m_v = NumericLimits<Acc0DataType>::Lowest(); }, m);
tile_elementwise_inout([](auto& l_v) { l_v = 0; }, l);
index_t iN0 = 0;
do
{
// S[i][j] = Q[i] * K[j]
const auto acc0_block_tile = block_gemm0_pipeline(
a0_dram_block_window, b0_dram_block_window, K0 / kK0PerBlock, p_smem_char);
// rowmax(S[i][j])
auto m_local = block_tile_reduce<Acc0DataType>(
acc0_block_tile, Sequence<1>{}, f_max, NumericLimits<Acc0DataType>::Lowest());
block_tile_reduce_sync(m_local, f_max);
// m[i][j-1]
const auto m_old = m;
// m[i][j]
tile_elementwise_inout(
[](auto& m_v, auto m_old_v, auto m_local_v) { m_v = max(m_old_v, m_local_v); },
m,
m_old,
m_local);
// P[i][j]
auto p =
make_static_distributed_tensor<Acc0DataType>(acc0_block_tile.GetTileDistribution());
constexpr auto p_spans = decltype(p)::GetDistributedSpans();
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_v = m.GetElementFromTileDistributedIndices(i_idx);
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto s_v = acc0_block_tile.GetElementFromTileDistributedIndices(i_j_idx);
const auto p_v = math::exp(s_v - m_v);
p.SetElementFromTileDistributedIndices(i_j_idx, p_v);
});
});
// rowsum(P[i][j])
auto rowsum_p =
block_tile_reduce<Acc0DataType>(p, Sequence<1>{}, f_sum, Acc0DataType{0});
block_tile_reduce_sync(rowsum_p, f_sum);
// l[i][j], O[i][j]
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_old_v = m_old.GetElementFromTileDistributedIndices(i_idx);
const auto m_v = m.GetElementFromTileDistributedIndices(i_idx);
const auto l_old_v = l.GetElementFromTileDistributedIndices(i_idx);
const auto tmp = math::exp(m_old_v - m_v);
const auto tmp2 = 1 / tmp;
auto l_v = tmp * l_old_v + rowsum_p.GetElementFromTileDistributedIndices(i_idx);
l.SetElementFromTileDistributedIndices(i_idx, l_v);
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// O[i][j]
const auto o_old_v =
acc1_block_tile.GetElementFromTileDistributedIndices(i_j_idx);
#if 0 // debug
// this uses the same equation as the FA v2 paper, but produces -nan here
const auto o_v = o_old_v * tmp2;
#elif 1
// this uses a different equation from the FA v2 paper, but produces the correct result
(void) tmp2;
const auto o_v = o_old_v * tmp;
#endif
acc1_block_tile.SetElementFromTileDistributedIndices(i_j_idx, o_v);
});
});
// type cast p into c0 (the A operand of block gemm1)
const auto c0_block_tile =
tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, p);
// Block GEMM1: acc1 += c0 * b1
{
// load b1
const auto b1_block_tile = load_tile(b1_dram_block_window);
// wait for block gemm0 pipeline to finish
ps.block_sync_lds();
store_tile(b1_lds_block_window, b1_block_tile);
// wait for store_tile to finish
ps.block_sync_lds();
// acc1 += c0 * b1
block_gemm1(acc1_block_tile, c0_block_tile, b1_lds_block_window);
// wait for block gemm1 to finish
ps.block_sync_lds();
}
// move tile windows
move_tile_window(b0_dram_block_window, {kN0PerBlock, 0});
move_tile_window(b1_dram_block_window, {0, kN0PerBlock});
iN0 += kN0PerBlock;
} while(iN0 < N0);
// after the last tile: normalize the output, O[i] = O[i] / l[i]
constexpr auto o_spans = decltype(acc1_block_tile)::GetDistributedSpans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto l_v = l.GetElementFromTileDistributedIndices(i_idx);
const auto tmp = 1 / l_v;
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto o_v = acc1_block_tile.GetElementFromTileDistributedIndices(i_j_idx);
const auto o_new_v = o_v * tmp;
acc1_block_tile.SetElementFromTileDistributedIndices(i_j_idx, o_new_v);
});
});
// type cast acc1 into c1
const auto c1_block_tile =
tile_elementwise_in(type_convert<C1DataType, Acc1DataType>, acc1_block_tile);
// store c1
auto c1_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_c1, make_tuple(M0, N1), make_tuple(Ldc1, 1), Number<32>{}, Number<1>{});
auto c1_dram_window =
make_tile_window(c1_dram_grid,
make_tuple(Number<kM0PerBlock>{}, Number<kN1PerBlock>{}),
{iM0, iN1},
c1_block_tile.GetTileDistribution());
store_tile(c1_dram_window, c1_block_tile);
}
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const Tensor<ADataType>& a_m_n, Tensor<BDataType>& b_m_n)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.GetLengths()[1];
AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
v_max = v_max < v_a ? v_a : v_max;
}
AccDataType v_exp_sum = 0;
// sum
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
v_exp_sum += ck::math::exp(v_a - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
b_m_n(m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
}
};
make_ParallelTensorFunctor(f, b_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency());
}
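// The max subtraction above is the standard numerically stable softmax: per row,
//   softmax(x)_n = exp(x_n - max_k x_k) / sum_j exp(x_j - max_k x_k),
// which is mathematically identical to exp(x_n) / sum_j exp(x_j) but keeps exp() from overflowing.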
...@@ -14,51 +14,14 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "reference_softmax.hpp"
#include "softmax.hpp"
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const Tensor<ADataType>& a_m_n, Tensor<BDataType>& b_m_n)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.GetLengths()[1];
AccDataType v_max = ck::NumericLimits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
v_max = v_max < v_a ? v_a : v_max;
}
AccDataType v_exp_sum = 0;
// sum
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
v_exp_sum += ck::math::exp(v_a - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
b_m_n(m, n) = ck::math::exp(v_a - v_max) / v_exp_sum;
}
};
make_ParallelTensorFunctor(f, b_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency());
}
int main(int argc, char* argv[])
{
using ADataType = float;
using ADataType = ck::half_t;
using AccDataType = float;
using BDataType = float;
using BDataType = ck::half_t;
ck::index_t M = 3328;
ck::index_t N = 4096;
...@@ -118,5 +81,8 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
LogRangeAsType<float>(std::cout << "dev: ", b_host_dev.mData, ", ") << std::endl;
LogRangeAsType<float>(std::cout << "ref: ", b_host_ref.mData, ", ") << std::endl;
return !ck::utils::check_err(b_host_dev, b_host_ref);
}
...@@ -26,6 +26,7 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
using namespace ck::tile_program::warp;
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
...@@ -46,6 +47,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
}
};
...
...@@ -182,10 +182,10 @@ template <typename AccDataType_,
index_t... InReduceDims,
typename ReduceFunc,
typename InDataType_>
__host__ __device__ auto block_tile_reduce(const InDistributedTensor_& in_tensor,
__device__ auto block_tile_reduce(const InDistributedTensor_& in_tensor,
Sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func,
const InDataType_& reduce_init)
{
using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>;
...@@ -222,6 +222,33 @@ __host__ void block_tile_reduce(AccDistributedTensor_&,
{
}
// FIXME: dummy host function for tile program
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
typename ReduceFunc,
typename InDataType_>
__host__ auto block_tile_reduce(const InDistributedTensor_&,
Sequence<InReduceDims...>,
const ReduceFunc&,
const InDataType_&)
{
using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>;
static_assert(is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
// declare acc_tensor
constexpr auto acc_dstr = make_static_tile_distribution(
ck::tile_program::detail::make_reduce_tile_distribution_encoding(
InDistributedTensor_::GetTileDistribution().GetStaticTileDistributionEncoding(),
Sequence<InReduceDims...>{}));
auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);
return acc_tensor;
}
// FIXME: dummy host function for tile program
template <typename AccDistributedTensor_, typename ReduceFunc>
__host__ void block_tile_reduce_sync(AccDistributedTensor_&, const ReduceFunc&)
...
...@@ -26,7 +26,7 @@ __host__ __device__ void tile_elementwise_inout(const InOutElementFunc& inout_el
type_pack_element<0, InOutDstrTensors...>::GetThreadBufferSize();
static_for<0, thread_buffer_size, 1>{}(
[&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer()(i)...); });
[&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer().At(i)...); });
}
template <typename InElementFunc, typename... InDstrTensors>
...
...@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M16N16K16 =
using WarpGemmMfmaF16F16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
...
...@@ -6,6 +6,7 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iomanip>
#include <numeric>
#include <thread>
#include <utility>
...@@ -19,7 +20,11 @@
#include "ck/library/utility/ranges.hpp"
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
std::ostream& LogRange(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
...@@ -28,13 +33,17 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << v;
os << std::setw(width) << std::setprecision(precision) << v;
}
return os;
}
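// e.g. LogRange(std::cout, v, ", ", 3, 8) now pads each element to width 8 and prints it with
// 3 significant digits; the added defaults roughly preserve the previous output format.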
template <typename T, typename Range>
std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
std::ostream& LogRangeAsType(std::ostream& os,
Range&& range,
std::string delim,
int precision = std::cout.precision(),
int width = 0)
{
bool first = true;
for(auto&& v : range)
...@@ -43,7 +52,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << static_cast<T>(v);
os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
}
return os;
}
...