Commit 0475a327 authored by dummycoderfe

Merge branch 'ck_tile/layernorm2d_fwd_optimize' into ck_tile/ln_add_cache_clear

parents c9b961ab 27ff3dec
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#include "ck_tile/core.hpp"
 namespace ck_tile {
 /*
 // clang-format off
@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
           typename Vector_, // contiguous pixels(vector size) along seq<M, N>
           index_t BlockSize_ =
               warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
-struct Layernorm2dShape
+struct Generic2dBlockShape
 {
     // block size
     static constexpr index_t Block_M = BlockTile_::at(number<0>{});
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile {
namespace element_wise {
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
}
};
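// Usage sketch (illustrative only, assuming the ck_tile types above): PassThrough
// is a binary functor that writes a converted copy of x into y, e.g. on the host:
//
//   ck_tile::element_wise::PassThrough pass;
//   ck_tile::fp16_t x = ck_tile::type_convert<ck_tile::fp16_t>(1.5f);
//   float y = 0.f;
//   pass(y, x); // dispatches operator()<float, ck_tile::fp16_t>, so y == 1.5f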
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
const float x_tmp = ck_tile::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
};
float scale_;
};
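// Usage sketch (illustrative only): Scale captures a float multiplier; the generic
// path converts through float, while the specializations keep native math where safe.
//
//   ck_tile::element_wise::Scale scale{0.5f};
//   float y;
//   scale(y, 8.0f);                     // y == 4.0f, via operator()<float, float>
//   int8_t q;
//   scale(q, static_cast<int8_t>(100)); // computed in float, converted back: q == 50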
struct ScaleAndResetNaNToMinusInfinity
{
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = ck_tile::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"Data type is not supported by this operation!");
y = ck_tile::sqrt(x);
};
};
struct Relu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
float x_f32 = ck_tile::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
}
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code uses the higher accuracy "exp" and "div"
// gpu code uses the lower accuracy "__ocml_exp_f32" and "rcp" functions
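// The code below uses the algebraically equivalent sigmoid form (a standard
// identity, spelled out here for clarity): with z = 0.797885f*x + 0.035677f*x^3,
// 0.5*(1 + tanh(z)) = 1/(1 + exp(-2z)), so letting u = -2z = x*(c1*x*x + c2)
// with c1 = -2*0.035677f and c2 = -2*0.797885f gives y = x / (1 + exp(u)).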
struct FastGelu
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u);
y = x * ck_tile::rcp(1.f + emu);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
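// (the constant 0.70710678118f below is 1/sqrt(2))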
struct Gelu
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::fp16_t(0.5) * x *
(ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
}
};
struct Sigmoid
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
};
};
struct Silu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
};
struct TanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
};
};
struct ACos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
};
};
struct Neg
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
};
};
struct ATan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
};
};
struct Sin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
};
};
struct ASinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
};
};
struct Cos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
};
};
struct ACosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
};
};
struct Tan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
};
};
struct ATanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
};
};
struct SinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
};
};
struct Ceil
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
};
};
struct Exp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
};
};
struct CosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
};
};
struct Floor
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
};
};
struct Log
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
};
};
struct ASin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asin(x);
};
};
struct Rcp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::rcp(x);
};
};
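// Swish(x) = x * sigmoid(beta * x); Silu above is the beta = 1 special case.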
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
std::is_same_v<Y, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
};
const float beta_;
};
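// SoftRelu(x) = log(1 + exp(alpha * x)) / alpha, i.e. softplus when alpha = 1.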
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck_tile::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
}
const float alpha_;
};
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
}
const float alpha_;
};
struct ConvInvscale
{
CK_TILE_HOST_DEVICE
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScale
{
CK_TILE_HOST_DEVICE
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScaleRelu
{
CK_TILE_HOST_DEVICE
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<ck_tile::fp8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
template <typename DstType, typename SrcType>
struct Cast
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
{
y = ck_tile::type_convert<DstType>(x);
};
};
// support fast conversion of int8 to fp16
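// How the (currently disabled) fast path below works, judging from its byte
// constants: __builtin_amdgcn_perm interleaves each input byte b with the byte
// 0x64, forming fp16 bit patterns 0x6400|b, i.e. the value 1024 + b with b landing
// in the mantissa; v_pk_add_f16 with neg_lo/neg_hi then subtracts the packed magic
// constant to strip that bias, converting two lanes per instruction.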
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
@@ -5,4 +5,6 @@
 #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
+#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -9,23 +9,29 @@ namespace ck_tile {
 // this epilogue just store out a M*N matrix, row major
-template <typename AccDataType_, typename ODataType_, bool kPadM_, bool kPadN_>
+template <typename AccDataType_,
+          typename ODataType_,
+          bool kPadM_,
+          bool kPadN_,
+          bool UseRawStore_ = true>
 struct Default2DEpilogueProblem
 {
     using AccDataType = remove_cvref_t<AccDataType_>;
     using ODataType = remove_cvref_t<ODataType_>;
     static constexpr bool kPadM = kPadM_;
     static constexpr bool kPadN = kPadN_;
+    static constexpr bool UseRawStore = UseRawStore_;
 };
 template <typename Problem_, typename Policy_ = void>
 struct Default2DEpilogue
 {
     using Problem = remove_cvref_t<Problem_>;
     using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
     using ODataType = remove_cvref_t<typename Problem::ODataType>;
     static constexpr bool kPadM = Problem::kPadM;
     static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool UseRawStore = Problem::UseRawStore;
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
@@ -36,7 +42,7 @@ struct Default2DEpilogue
     {
         // TODO: this is ugly
-        if constexpr(kPadM || kPadN)
+        if constexpr(UseRawStore && (kPadM || kPadN))
        {
             store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
             buffer_store_fence();
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool UseSmoothInputScale_,
bool UseRawStore_ = true,
bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr bool UseMax3 = UseMax3_;
};
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename XScaleDataType_,
typename YScaleDataType_,
typename ODataType_,
typename BlockShape_,
typename Traits_>
struct DynamicQuantEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consume a generic 2d shape
using Traits = remove_cvref_t<Traits_>;
};
// TODO: we should put descriptor creation function into policy
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
static constexpr bool kPadM = Problem::Traits::kPadM;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
static constexpr bool UseMax3 = Problem::Traits::UseMax3;
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2d<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dSync<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
{
using S = BlockShape;
#if 0
// don't remove this
// Note: setting the encoding purposely like this results in a compile failure.
// TODO: let x_scale create a local scratch to accept arbitrary acc input (of the same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<0, 1, 1>,
sequence<0, 0, 3>>{});
#else
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
#endif
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
return reduce_crosswarp_sync.GetSmemSize();
}
// TODO: this function assumes the store-out vector size is the same as the
// OAccTile last-dimension size; how do we fix this?
template <typename ODramWindowTmp,
typename XScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const XScaleWindow& x_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
const auto x_scale_window =
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
auto x_scale = load_tile(x_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto row_absmax = [&]() {
constexpr auto y_size_per_row =
OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
number<1>{});
if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
{
// fast path: each v_max3_f32 folds the accumulator with two |values| at once
// (hence the even per-row element count check above)
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
}
else
{
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
}
}();
reduce_sync(row_absmax, f_absmax);
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
// here y_scale is Acc type; it is converted to the YScale type below
auto y_scale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
},
row_absmax);
store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
});
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
}
}
};
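// Reference math for the epilogue above, as a scalar sketch (illustrative only):
// per output row i, with columns j,
//   a[i][j]      = acc[i][j] * x_scale[j]              // smooth-quant input scale
//   row_absmax_i = max_j |a[i][j]|                     // block-wide reduction
//   y_scale_i    = row_absmax_i / numeric<ODataType>::max()
//   out[i][j]    = a[i][j] / y_scale_i                 // then cast to ODataType
// so the stored values occupy the full ODataType range and y_scale_i undoes the
// scaling at dequantization time.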
} // namespace ck_tile
@@ -43,4 +43,5 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -82,10 +82,10 @@ struct FmhaFwdKernel
         if (kPadHeadDimV) n += "dv";
         return n.empty() ? n : std::string("p") + n; }();
     return
-        _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
+        _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
         "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
         "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
+        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
         "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
         "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
         "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
@@ -657,7 +657,7 @@ struct FmhaFwdKernel
         {
             return pad_tensor_view(
                 q_dram_naive,
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
+                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                 sequence<kPadSeqLenQ, kPadHeadDimQ>{});
         }
         else
@@ -724,7 +724,7 @@ struct FmhaFwdKernel
             [&]() {
                 if constexpr(FmhaPipeline::kQLoadOnce)
                     return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kK0BlockLength>{});
+                                      number<FmhaPipeline::kSubQKHeaddim>{});
                 else
                     return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
             }(),
...
@@ -78,10 +78,10 @@ struct FmhaFwdSplitKVKernel
         if (kPadHeadDimV) n += "dv";
         return n.empty() ? n : std::string("p") + n; }();
     return
-        _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
+        _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
         "_" + (kIsGroupMode ? "group" : "batch") + "_"
         "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
+        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
         "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
         "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
         "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
@@ -586,7 +586,7 @@ struct FmhaFwdSplitKVKernel
         {
             return pad_tensor_view(
                 q_dram_naive,
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
+                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                 sequence<kPadSeqLenQ, kPadHeadDimQ>{});
         }
         else
@@ -735,7 +735,7 @@ struct FmhaFwdSplitKVKernel
             [&]() {
                 if constexpr(FmhaPipeline::kQLoadOnce)
                     return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kK0BlockLength>{});
+                                      number<FmhaPipeline::kSubQKHeaddim>{});
                 else
                     return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
             }(),
...
@@ -34,12 +34,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(2 <= k0_loops);
...
@@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(2 <= k0_loops);
...
@@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             return 1;
         }
-        if constexpr(kK0BlockLength <= 32)
+        if constexpr(kQKHeaddim <= 32)
        {
             if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                          FmhaMask::IsMasking)
@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
             else
                 return 2;
         }
-        else if constexpr(kK0BlockLength <= 64)
+        else if constexpr(kQKHeaddim <= 64)
         {
             if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 2;
             else
                 return 3;
         }
-        else if constexpr(kK0BlockLength <= 128)
+        else if constexpr(kQKHeaddim <= 128)
         {
             if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 1;
             else
                 return 2;
         }
-        else if constexpr(kK0BlockLength <= 256)
+        else if constexpr(kQKHeaddim <= 256)
         {
             return 1;
         }
@@ -334,12 +335,12 @@ struct BlockFmhaPipelineQRKSVSAsync
         move_tile_window(k_dram_window, {0, kK0});
         __builtin_amdgcn_sched_barrier(0);
-        buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
+        buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
         (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
         // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(1 <= k0_loops);
@@ -359,7 +360,7 @@ struct BlockFmhaPipelineQRKSVSAsync
         if constexpr(i_k0 < k0_loops - 1)
             move_tile_window(k_dram_window, {0, kK0});
-        async_load_fence(k_dram_window.get_num_access());
+        async_load_fence(k_dram_window.get_num_of_access());
         __builtin_amdgcn_s_barrier();
         __builtin_amdgcn_sched_barrier(0);
         gemm_0(s_acc,
...
@@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(2 <= k0_loops);
...
@@ -36,12 +36,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(2 <= k0_loops);
...
@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
+       constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
        constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
        constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
    template<> struct
    LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
+   template<> struct
+   LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
    template<> struct
    LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
    // clang-format on
@@ -332,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
    {
        using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
        constexpr index_t kN0 = BlockFmhaShape::kN0;
        constexpr index_t kK0 = BlockFmhaShape::kK0;
        constexpr index_t kK1 = BlockFmhaShape::kK1;
-       constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
-       constexpr index_t k0_loops       = kK0BlockLength / kK0;
+       constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
+       constexpr index_t k0_loops   = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;
        return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
...
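The new specialization covers the (NumPrefetchK=3, NumPrefetchV=3, k0_loops=3, k1_loops=4) schedule: one LDS buffer index per pipeline stage, k0_loops + k1_loops entries in total, chosen so a prefetch never lands in a buffer that is still being consumed. A toy model of the basic rotation follows; note the real tables are hand-tuned and may hold a buffer across consecutive stages, as the repeated 0 in <1, 2, 0, 0, 1, 2, 0> shows:

```cpp
// Toy model: assign one of three LDS staging buffers to each of the
// k0_loops + k1_loops pipeline stages by simple rotation. The hand-written
// LdsBufferSequence tables start from buffer 1 and deviate where a stage can
// safely reuse the previous buffer.
#include <cstdio>

int main()
{
    constexpr int num_buffers = 3; // triple-buffered LDS
    constexpr int k0_loops = 3, k1_loops = 4;
    for(int stage = 0; stage < k0_loops + k1_loops; ++stage)
        std::printf("stage %d -> lds buffer %d\n", stage, (stage + 1) % num_buffers);
}
```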
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -7,6 +7,20 @@
 namespace ck_tile {
+static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
+{
+    if(len == 96)
+        return 128;
+    if(len == 160)
+        return 256;
+    // only lengths of 96, 160 and powers of two are supported
+    if(!(len & (len - 1)))
+        return len;
+    return 0;
+};
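The helper rounds a head-dim tile length up to the nearest length the pipelines can handle: 96 and 160 are promoted to the next supported power of two, powers of two pass through unchanged, and anything else yields 0 so a downstream check can reject it. A standalone restatement of the mapping, verifiable at compile time (plain `int` and no macros stand in for the ck_tile types):

```cpp
// Standalone restatement of ceil_to_qualified_tile_length for checking the
// mapping; index_t and CK_TILE_HOST_DEVICE are replaced with plain C++.
#include <cassert>

constexpr int ceil_to_qualified_tile_length(int len)
{
    if(len == 96)
        return 128;
    if(len == 160)
        return 256;
    if(!(len & (len - 1))) // power of two passes through
        return len;
    return 0; // unsupported length, rejected by callers
}

static_assert(ceil_to_qualified_tile_length(96) == 128);
static_assert(ceil_to_qualified_tile_length(160) == 256);
static_assert(ceil_to_qualified_tile_length(64) == 64);
static_assert(ceil_to_qualified_tile_length(48) == 0); // not a qualified length

int main() { return 0; }
```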
 template <typename BlockTile_, // sequence<...
           typename Gemm0BlockWarps_,
           typename Gemm0WarpTile_,
@@ -36,10 +50,12 @@ struct TileFmhaShape
    static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
    static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
    static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
-   static constexpr index_t kK0BlockLength =
+   static constexpr index_t kQKHeaddim =
        BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q
                                    // at once (or repeatedly load Q as a whole tile)
-   static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0");
+   static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
+   static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
...
@@ -24,6 +24,8 @@
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
@@ -37,4 +39,5 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1
        BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
        BlockGemmARegBGmemCRegV1DefaultPolicy>;
-   CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+   CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
    {
        return sizeof(BDataType) *
               Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size();
...
@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    // C += A * B
-   template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
+   template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                                  const ABlockWindowTmp& a_block_window_tmp,
-                                  const BBlockWindowTmp& b_block_window_tmp) const
+                                  const ABlockWindow& a_block_window,
+                                  const BBlockWindow& b_block_window) const
    {
-       static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
-                         std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
+       static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
+                         std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");
-       constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
-       constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
-       constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
+       constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
+       constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
+       constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1
        // construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
-           a_block_window_tmp.get_bottom_tensor_view(),
+           a_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
-           a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
+           a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1
        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
-           b_block_window_tmp.get_bottom_tensor_view(),
+           b_block_window.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
-           b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
+           b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1
    }
    // C = A * B
-   template <typename ABlockTensorTmp, typename BBlockWindowTmp>
+   template <typename ABlockTensorTmp, typename BBlockWindow>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
-                                  const BBlockWindowTmp& b_block_window_tmp) const
+                                  const BBlockWindow& b_block_window) const
    {
        auto c_block_tensor = MakeCBlockTile();
-       operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
+       operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
        return c_block_tensor;
    }
};
...
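The `Tmp` suffix is dropped from the window parameters since the windows are used directly, and the calling convention is unchanged: the value-returning "C = A * B" overload just builds a zeroed accumulator and forwards to the accumulating "C += A * B" overload. A CPU analogue of that design (plain arrays stand in for block windows and tiles; sizes are illustrative):

```cpp
// CPU analogue of the two operator() overloads of BlockGemmASmemBSmemCRegV1.
#include <array>

constexpr int M = 4, N = 4, K = 8;
using Tile = std::array<float, M * N>;

// C += A * B; B is laid out N x K, matching the (N, K) B block window above
void gemm_acc(Tile& c, const float (&a)[M][K], const float (&b)[N][K])
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m][k] * b[n][k];
}

// C = A * B, structured like the ck_tile overload: make the tile, accumulate
Tile gemm(const float (&a)[M][K], const float (&b)[N][K])
{
    Tile c{}; // zero-initialized accumulator, like MakeCBlockTile()
    gemm_acc(c, a, b);
    return c;
}

int main()
{
    float a[M][K] = {}, b[N][K] = {};
    auto c = gemm(a, b);
    (void)c;
}
```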
@@ -3,12 +3,13 @@
 #pragma once
-#include "ck_tile/core.hpp"
-#include "ck_tile/ops/common.hpp"
 #include <iostream>
 #include <string>
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 namespace ck_tile {
 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
@@ -17,20 +18,19 @@ struct GemmKernel
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
-   static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
-
-   using ADataType    = remove_cvref_t<typename GemmPipeline::ADataType>;
-   using BDataType    = remove_cvref_t<typename GemmPipeline::BDataType>;
-   using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
-   using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
-
-   using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
-   using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
-   using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
+   using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
+   using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
+   using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
+   static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+
+   using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
+   using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
+   // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+   using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
-   __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
+   __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
-       return TilePartitioner::GridSize(M_size, N_size, Batch_size);
+       return TilePartitioner::GridSize(M, N, KBatch);
    }
    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
@@ -40,34 +40,30 @@ struct GemmKernel
        const void* a_ptr;
        const void* b_ptr;
        void* c_ptr;
-
-       float epsilon;
-
-       ck_tile::index_t M;
-       ck_tile::index_t N;
-       ck_tile::index_t K;
-       ck_tile::index_t stride_A;
-       ck_tile::index_t stride_B;
-       ck_tile::index_t stride_C;
+       index_t M;
+       index_t N;
+       index_t K;
+       index_t stride_A;
+       index_t stride_B;
+       index_t stride_C;
    };
    CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                            const void* b_ptr,
                                                            void* c_ptr,
-                                                           float epsilon,
-                                                           ck_tile::index_t M,
-                                                           ck_tile::index_t N,
-                                                           ck_tile::index_t K,
-                                                           ck_tile::index_t stride_A,
-                                                           ck_tile::index_t stride_B,
-                                                           ck_tile::index_t stride_C)
+                                                           index_t M,
+                                                           index_t N,
+                                                           index_t K,
+                                                           index_t stride_A,
+                                                           index_t stride_B,
+                                                           index_t stride_C)
    {
-       return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
+       return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
    }
-   CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+   CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
-       return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+       return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
@@ -78,13 +74,13 @@ struct GemmKernel
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
        // Convert pointers to tensor views
        auto a_tensor_view = [&]() {
-           if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
+           if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
-                   make_tuple(1, kargs.stride_A),
-                   number<GemmPipeline::AlignmentA>{},
+                   make_tuple(kargs.stride_A, 1),
+                   number<GemmPipeline::VectorSizeA>{},
                    number<1>{});
            }
            else
@@ -92,29 +88,29 @@ struct GemmKernel
                return make_naive_tensor_view<address_space_enum::global>(
                    a_start,
                    make_tuple(kargs.M, kargs.K),
-                   make_tuple(kargs.stride_A, 1),
-                   number<GemmPipeline::AlignmentA>{},
+                   make_tuple(1, kargs.stride_A),
+                   number<1>{},
                    number<1>{});
            }
        }();
        auto b_tensor_view = [&]() {
-           if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
+           if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(1, kargs.stride_B),
-                   number<GemmPipeline::AlignmentB>{},
+                   number<1>{},
                    number<1>{});
            }
            else
-           { // Default NK layout
+           {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_start,
                    make_tuple(kargs.N, kargs.K),
                    make_tuple(kargs.stride_B, 1),
-                   number<GemmPipeline::AlignmentB>{},
+                   number<GemmPipeline::VectorSizeB>{},
                    number<1>{});
            }
        }();
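The layout test now asks the direct question (is A row-major?) instead of the inverted one, and the strides follow: a row-major (M, K) view uses strides (stride_A, 1), a column-major one (1, stride_A), with the vector size only applied on the contiguous axis. A small host-side model of that addressing (plain C++, not the ck_tile view types):

```cpp
// Host-side model of the naive 2D view addressing used above: element (i, j)
// of a view with strides (s0, s1) lives at offset i*s0 + j*s1.
#include <cstddef>

constexpr std::size_t offset(std::size_t i, std::size_t j, std::size_t s0, std::size_t s1)
{
    return i * s0 + j * s1;
}

int main()
{
    constexpr std::size_t ld = 8; // leading-dimension stride, like kargs.stride_A
    // row-major: strides (ld, 1) -> stepping a row jumps ld elements
    static_assert(offset(1, 0, ld, 1) == 8);
    static_assert(offset(0, 1, ld, 1) == 1);
    // column-major: strides (1, ld) -> stepping a column jumps ld elements
    static_assert(offset(0, 1, 1, ld) == 8);
    static_assert(offset(1, 0, 1, ld) == 1);
}
```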
@@ -122,10 +118,12 @@ struct GemmKernel
        auto a_pad_view = pad_tensor_view(
            a_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-           sequence < 0,
-               GemmPipeline::kPadA ? 1 : 0 > {});
+           // somehow clang-format is splitting below line into multiple.
+           // clang-format off
+           sequence<false, GemmPipeline::kPadA>{});
+           // clang-format on
-       auto ABlockWindow = make_tile_window(
+       auto a_block_window = make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
            {i_m, 0});
@@ -133,10 +131,11 @@ struct GemmKernel
        auto b_pad_view = pad_tensor_view(
            b_tensor_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-           sequence < 0,
-               GemmPipeline::kPadB ? 1 : 0 > {});
+           // clang-format off
+           sequence<false, GemmPipeline::kPadB>{});
+           // clang-format on
-       auto BBlockWindow = make_tile_window(
+       auto b_block_window = make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
            {i_n, 0});
@@ -144,20 +143,21 @@ struct GemmKernel
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];
-       const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
-       auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
+       const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-       CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+       // Run GEMM cooperatively by the whole workgroup.
+       auto c_block_tile =
+           GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+
+       CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
        auto c_tensor_view = [&]() {
-           if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+           if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
-                   make_tuple(1, kargs.stride_C),
-                   number<GemmPipeline::AlignmentC>{},
+                   make_tuple(kargs.stride_C, 1),
+                   number<GemmPipeline::VectorSizeC>{},
                    number<1>{});
            }
            else
@@ -165,8 +165,8 @@ struct GemmKernel
                return make_naive_tensor_view<address_space_enum::global>(
                    c_start,
                    make_tuple(kargs.M, kargs.N),
-                   make_tuple(kargs.stride_C, 1),
-                   number<GemmPipeline::AlignmentC>{},
+                   make_tuple(1, kargs.stride_C),
+                   number<1>{},
                    number<1>{});
            }
        }();
@@ -174,14 +174,15 @@ struct GemmKernel
        auto c_pad_view = pad_tensor_view(
            c_tensor_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-           sequence < 0,
-               GemmPipeline::kPadC ? 1 : 0 > {});
+           // clang-format off
+           sequence<false, GemmPipeline::kPadC>{});
+           // clang-format on
-       auto CBlockWindow_pad = make_tile_window(
+       auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
            {i_m, i_n});
-       EpiloguePipeline{}(CBlockWindow_pad, acc);
+       EpiloguePipeline{}(c_block_window, c_block_tile);
    }
};
...
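With epsilon gone from the kernel arguments, the host side becomes a plain GEMM setup: build the kargs, ask the kernel for its grid and block sizes, and launch. A compilable mock of that interface shape (MockKernel mirrors the new static API; the real GemmKernel is driven the same way by the ck_tile example runners, whose launch wrappers are not shown here):

```cpp
// Compilable mock of the reworked host-side interface (no epsilon in kargs).
#include <cstdint>

using index_t = std::int32_t;

struct MockKargs
{
    const void* a_ptr; const void* b_ptr; void* c_ptr;
    index_t M, N, K, stride_A, stride_B, stride_C;
};

struct MockKernel // stands in for GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>
{
    static constexpr MockKargs MakeKargs(const void* a, const void* b, void* c,
                                         index_t M, index_t N, index_t K,
                                         index_t sa, index_t sb, index_t sc)
    {
        return MockKargs{a, b, c, M, N, K, sa, sb, sc};
    }
};

int main()
{
    float a[16]{}, b[16]{}, c[16]{};
    auto kargs = MockKernel::MakeKargs(a, b, c, 4, 4, 4, 4, 4, 4);
    (void)kargs; // would be handed to the device kernel via the launch wrapper
}
```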