Commit 0144b4f4 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'develop' into jizhan/reduce_threadwise_multi_d

parents 300337cd 34f3dfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace ck_tile {
struct gpu_timer
{
CK_TILE_HOST gpu_timer()
{
HIP_CHECK_ERROR(hipEventCreate(&start_evt));
HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
}
CK_TILE_HOST ~gpu_timer() noexcept(false)
{
HIP_CHECK_ERROR(hipEventDestroy(start_evt));
HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
}
CK_TILE_HOST void start(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
}
CK_TILE_HOST void stop(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
}
// return in ms
CK_TILE_HOST float duration() const
{
float ms = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
return ms;
}
private:
hipEvent_t start_evt, stop_evt;
};
struct cpu_timer
{
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST void start(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
start_tick = std::chrono::high_resolution_clock::now();
}
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST void stop(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
stop_tick = std::chrono::high_resolution_clock::now();
}
// return in ms
CK_TILE_HOST float duration() const
{
double sec =
std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
.count();
return static_cast<float>(sec * 1e3);
}
private:
std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
};
} // namespace ck_tile
...@@ -23,13 +23,13 @@ VERTICAL: ...@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5 [0] 1 2 3 4 5
[0] 1 2 3 4 5 [0] 1 2 3 4 5
TOP_LEFT: TOP_LEFT(but negative):
[0] 1 2 3 4 5 [0] 1 2 3 4 5
1 [0] 1 2 3 4 1 [0] 1 2 3 4
2 1 [0] 1 2 3 2 1 [0] 1 2 3
3 2 1 [0] 1 2 3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT: FROM_BOTTOM_RIGHT(but negative):
2 1 [0] 1 2 3 2 1 [0] 1 2 3
3 2 1 [0] 1 2 3 2 1 [0] 1 2
4 3 2 1 [0] 1 4 3 2 1 [0] 1
...@@ -54,7 +54,7 @@ struct Alibi ...@@ -54,7 +54,7 @@ struct Alibi
index_t x_total_, index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL) AlibiMode mode_ = AlibiMode::VERTICAL)
{ {
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope; slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() { shift_left_up = [&]() {
if(RowMajor) if(RowMajor)
......
...@@ -76,7 +76,7 @@ struct FmhaFwdKernel ...@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return n.empty() ? n : std::string("p") + n; }(); return n.empty() ? n : std::string("p") + n; }();
return return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
...@@ -702,7 +702,7 @@ struct FmhaFwdKernel ...@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else else
{ {
return Alibi<SaccDataType, true>{ return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
} }
} }
else else
......
...@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner ...@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, static constexpr const char* name = "shb";
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t hdim_v_) ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
...@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner ...@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
} }
}; };
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "hbs";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile } // namespace ck_tile
...@@ -131,74 +131,74 @@ int main() ...@@ -131,74 +131,74 @@ int main()
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5}); 0, 1, 2, 3, 4, 5});
rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
1, 0, 1, 2, 3, 4, -1, 0, -1, -2, -3, -4,
2, 1, 0, 1, 2, 3, -2, -1, 0, -1, -2, -3,
3, 2, 1, 0, 1, 2}); -3, -2, -1, 0, -1, -2});
rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
1, 0, 1, 2, -1, 0, -1, -2,
2, 1, 0, 1, -2, -1, 0, -1,
3, 2, 1, 0, -3, -2, -1, 0,
4, 3, 2, 1, -4, -3, -2, -1,
5, 4, 3, 2}); -5, -4, -3, -2});
rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
1, 0, 1, -1, 0, -1,
2, 1, 0}); -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3,
3, 2, 1, 0, 1, 2, -3, -2, -1, 0, -1, -2,
4, 3, 2, 1, 0, 1, -4, -3, -2, -1, 0, -1,
5, 4, 3, 2, 1, 0}); -5, -4, -3, -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
1, 2, 3, 4, -1, -2, -3, -4,
0, 1, 2, 3, 0, -1, -2, -3,
1, 0, 1, 2, -1, 0, -1, -2,
2, 1, 0, 1, -2, -1, 0, -1,
3, 2, 1, 0}); -3, -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
1, 0, 1, -1, 0, -1,
2, 1, 0}); -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5}); 0, 1, 2, 3, 4, 5});
rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
1, 0, 1, 2, 3, 4, -1, 0, -1, -2, -3, -4,
2, 1, 0, 1, 2, 3, -2, -1, 0, -1, -2, -3,
3, 2, 1, 0, 1, 2}); -3, -2, -1, 0, -1, -2});
rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
1, 0, 1, 2, -1, 0, -1, -2,
2, 1, 0, 1, -2, -1, 0, -1,
3, 2, 1, 0, -3, -2, -1, 0,
4, 3, 2, 1, -4, -3, -2, -1,
5, 4, 3, 2}); -5, -4, -3, -2});
rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
1, 0, 1, -1, 0, -1,
2, 1, 0}); -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3,
3, 2, 1, 0, 1, 2, -3, -2, -1, 0, -1, -2,
4, 3, 2, 1, 0, 1, -4, -3, -2, -1, 0, -1,
5, 4, 3, 2, 1, 0}); -5, -4, -3, -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
1, 2, 3, 4, -1, -2, -3, -4,
0, 1, 2, 3, 0, -1, -2, -3,
1, 0, 1, 2, -1, 0, -1, -2,
2, 1, 0, 1, -2, -1, 0, -1,
3, 2, 1, 0}); -3, -2, -1, 0});
rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
1, 0, 1, -1, 0, -1,
2, 1, 0}); -2, -1, 0});
rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692, rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment