Commit 0144b4f4 authored by letaoqin
Browse files

Merge branch 'develop' into jizhan/reduce_threadwise_multi_d

parents 300337cd 34f3dfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace ck_tile {
// Timer that measures the time elapsed on the GPU between start() and stop()
// using a pair of HIP events. The two events are created in the constructor
// and destroyed in the destructor, so one object can be reused for several
// measurements.
//
// Every HIP call is wrapped in HIP_CHECK_ERROR, so failures are reported
// through whatever mechanism that macro implements.
struct gpu_timer
{
    // Create the start/stop events used for the measurement.
    CK_TILE_HOST gpu_timer()
    {
        HIP_CHECK_ERROR(hipEventCreate(&start_evt));
        HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
    }

    // The timer owns raw event handles that are destroyed exactly once in the
    // destructor. A copy would share the same handles and destroy them a
    // second time, so copying is disabled.
    gpu_timer(const gpu_timer&)            = delete;
    gpu_timer& operator=(const gpu_timer&) = delete;

    // Destroy both events. Declared noexcept(false) so that an error raised by
    // HIP_CHECK_ERROR is allowed to leave the destructor.
    CK_TILE_HOST ~gpu_timer() noexcept(false)
    {
        HIP_CHECK_ERROR(hipEventDestroy(start_evt));
        HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
    }

    // Begin a measurement on stream `s`.
    CK_TILE_HOST void start(const hipStream_t& s)
    {
        // Wait for all previously submitted device work first, so the start
        // event is not queued behind unrelated work.
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
    }

    // End the measurement on stream `s` and block the host until the stop
    // event has completed, so duration() can be queried right afterwards.
    CK_TILE_HOST void stop(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
        HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
    }

    // return in ms
    // Elapsed time between the recorded start and stop events.
    CK_TILE_HOST float duration() const
    {
        float ms = 0;
        HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
        return ms;
    }

    private:
    // Value-initialized so the handles never hold indeterminate values.
    hipEvent_t start_evt{};
    hipEvent_t stop_evt{};
};
// Host-side timer with the same start()/stop()/duration() shape as gpu_timer.
// It brackets the measured region with full device synchronizations and reads
// a host clock on each side.
//
// The clock is std::chrono::steady_clock: it is monotonic, so adjustments to
// the system wall clock between start() and stop() cannot distort the
// measured interval (high_resolution_clock gives no such guarantee).
struct cpu_timer
{
    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    // The stream argument is unused; the whole device is synchronized before
    // the start tick is taken.
    CK_TILE_HOST void start(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        start_tick = clock_type::now();
    }

    // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
    // The stream argument is unused; the whole device is synchronized before
    // the stop tick is taken, so all submitted work is included.
    CK_TILE_HOST void stop(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        stop_tick = clock_type::now();
    }

    // return in ms
    // Host time elapsed between the last start() and stop() calls.
    CK_TILE_HOST float duration() const
    {
        // Compute in double-precision seconds, then scale to milliseconds.
        const std::chrono::duration<double> elapsed_sec = stop_tick - start_tick;
        return static_cast<float>(elapsed_sec.count() * 1e3);
    }

    private:
    using clock_type = std::chrono::steady_clock;

    std::chrono::time_point<clock_type> start_tick;
    std::chrono::time_point<clock_type> stop_tick;
};
} // namespace ck_tile
...@@ -23,13 +23,13 @@ VERTICAL: ...@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5 [0] 1 2 3 4 5
[0] 1 2 3 4 5 [0] 1 2 3 4 5
TOP_LEFT: TOP_LEFT(but negative):
[0] 1 2 3 4 5 [0] 1 2 3 4 5
1 [0] 1 2 3 4 1 [0] 1 2 3 4
2 1 [0] 1 2 3 2 1 [0] 1 2 3
3 2 1 [0] 1 2 3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT: FROM_BOTTOM_RIGHT(but negative):
2 1 [0] 1 2 3 2 1 [0] 1 2 3
3 2 1 [0] 1 2 3 2 1 [0] 1 2
4 3 2 1 [0] 1 4 3 2 1 [0] 1
...@@ -54,7 +54,7 @@ struct Alibi ...@@ -54,7 +54,7 @@ struct Alibi
index_t x_total_, index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL) AlibiMode mode_ = AlibiMode::VERTICAL)
{ {
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope; slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() { shift_left_up = [&]() {
if(RowMajor) if(RowMajor)
......
...@@ -76,7 +76,7 @@ struct FmhaFwdKernel ...@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return n.empty() ? n : std::string("p") + n; }(); return n.empty() ? n : std::string("p") + n; }();
return return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
...@@ -702,7 +702,7 @@ struct FmhaFwdKernel ...@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else else
{ {
return Alibi<SaccDataType, true>{ return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment