Commit 55cdf2b9 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/layernorm_fusion

parents 4b59b5c9 b098b71b
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile {
namespace element_wise {
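// All functors below follow the out-parameter convention op(y, x): the
// transformed value of x is written into y. A minimal usage sketch
// (hypothetical values; assumes the surrounding ck_tile types are available):
//
//   float x = 3.f;
//   ck_tile::fp16_t y16;
//   ck_tile::element_wise::PassThrough{}(y16, x); // y16 = fp16(3.f)
//   float ys;
//   ck_tile::element_wise::Scale{2.f}(ys, x);     // ys = 6.f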
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
}
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
const float x_tmp = ck_tile::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
};
float scale_;
};
struct ScaleAndResetNaNToMinusInfinity
{
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = ck_tile::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"Data type is not supported by this operation!");
y = ck_tile::sqrt(x);
};
};
struct Relu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
float x_f32 = ck_tile::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
}
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code uses the higher-accuracy "exp" and "div"
// gpu code uses the lower-accuracy "__ocml_exp_f32" and "rcp" functions
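// Why the code below uses exp instead of tanh: with
//   z = sqrt(2/pi)*(x + 0.044715*x^3) = x*(0.797885f + 0.035677f*x*x)
// and the identity tanh(z) = 1 - 2/(exp(2*z) + 1), the definition above becomes
//   y = 0.5*x*(1 + tanh(z)) = x / (1 + exp(-2*z)),
// so u = -2*z = x*(c1*x*x + c2) with c1 = -2*0.035677f and c2 = -2*0.797885f
// (0.035677 = 0.044715*sqrt(2/pi), 0.797885 = sqrt(2/pi)).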
struct FastGelu
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0f * 0.035677f;
const float c2 = -2.0f * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0f * 0.035677f;
const float c2 = -2.0f * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u);
y = x * ck_tile::rcp(1.f + emu);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
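// 0.70710678118f below is 1/sqrt(2); 0.5*(1 + erf(x/sqrt(2))) is the standard
// normal CDF, i.e. this is the "exact" GeLU, unlike the tanh approximation above.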
struct Gelu
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::fp16_t(0.5) * x *
(ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
}
};
struct Sigmoid
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
};
};
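// SiLU is swish with beta = 1: silu(x) = x * sigmoid(x).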
struct Silu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
};
struct TanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
};
};
struct ACos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
};
};
struct Neg
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
};
};
struct ATan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
};
};
struct Sin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
};
};
struct ASinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
};
};
struct Cos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
};
};
struct ACosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
};
};
struct Tan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
};
};
struct ATanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
};
};
struct SinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
};
};
struct Ceil
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
};
};
struct Exp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
};
};
struct CosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
};
};
struct Floor
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
};
};
struct Log
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
};
};
struct ASin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asin(x);
};
};
struct Rcp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::rcp(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
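// swish(x) = x * sigmoid(beta * x), evaluated here as x / (1 + exp(-beta * x))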
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
std::is_same_v<Y, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
};
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck_tile::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
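// expm1(x) = exp(x) - 1 without cancellation, keeping ELU accurate near x = 0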
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
}
const float alpha_;
};
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
}
const float alpha_;
};
struct ConvInvscale
{
CK_TILE_HOST_DEVICE
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScale
{
CK_TILE_HOST_DEVICE
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScaleRelu
{
CK_TILE_HOST_DEVICE
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<ck_tile::fp8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
template <typename DstType, typename SrcType>
struct Cast
{
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
{
y = ck_tile::type_convert<DstType>(x);
};
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
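// How the perm trick works (a sketch): each selected byte b lands in the low
// mantissa bits of a half whose high byte is 0x64, so the packed half reads as
// 1024.0 + b (0x6400 is fp16 1024.0). The v_pk_add_f16 below negates its second
// operand, computing (1024 + b) - 1152 per lane (0x6480 is fp16 1152.0), i.e.
// b - 128, a biased-uint8 to fp16 conversion with no v_cvt instructions.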
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
......@@ -69,7 +69,8 @@ struct FmhaFwdKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
-using gbr = typename bfs::Gemm0BlockWarps;
+using g0br = typename bfs::Gemm0BlockWarps;
+using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
......@@ -85,7 +86,8 @@ struct FmhaFwdKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
......
......@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
-using gbr = typename bfs::Gemm0BlockWarps;
+using g0br = typename bfs::Gemm0BlockWarps;
+using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
......@@ -81,7 +82,8 @@ struct FmhaFwdSplitKVKernel
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
......@@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o_acc, 1),
-number<1>{},
+number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
return pad_tensor_view(
......
......@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
-ck_tile::integer_divide_ceil(hdim_v, kN1),
-nhead * num_splits,
+ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
+nhead,
batch_size);
}
......@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return ck_tile::make_tuple(quotient, modulus);
};
-const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
-const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
+const auto [mn, i_split] = f(blockIdx.x, num_splits);
+const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
+const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
......
......@@ -64,6 +64,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
+static constexpr index_t kAlignmentOacc =
+kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
......@@ -252,11 +255,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-auto bias_dram_window = make_tile_window(
-bias_dram_block_window_tmp.get_bottom_tensor_view(),
-bias_dram_block_window_tmp.get_window_lengths(),
-{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
-Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+auto bias_dram_window =
+make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+bias_dram_block_window_tmp.get_window_lengths(),
+{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
+Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
......
......@@ -9,11 +9,20 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
-using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
-BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-/* AsyncCopyK = */ false,
-/* AsyncCopyV = */ false,
-/* NumPrefetchK = */ 1,
-/* NumPrefetchV = */ 1>;
+struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
+: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+/* AsyncCopyK = */ false,
+/* AsyncCopyV = */ false,
+/* NumPrefetchK = */ 1,
+/* NumPrefetchV = */ 1>
+{
+template <typename Problem>
+CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
+{
+using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
+return static_cast<index_t>(16 / sizeof(OaccDataType));
+}
+};
} // namespace ck_tile
......@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
-static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
-static constexpr bool kIsGroupMode = kIsGroupMode_;
+static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
+static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
+static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
+static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
-static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
-static constexpr bool kIsGroupMode = kIsGroupMode_;
+static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
+static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
+static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
+static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
......
......@@ -242,11 +242,11 @@ struct BlockFmhaPipelineQRKSVS
{seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-auto bias_dram_window = make_tile_window(
-bias_dram_block_window_tmp.get_bottom_tensor_view(),
-bias_dram_block_window_tmp.get_window_lengths(),
-{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+auto bias_dram_window =
+make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+bias_dram_block_window_tmp.get_window_lengths(),
+{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
......
......@@ -314,11 +314,11 @@ struct BlockFmhaPipelineQRKSVSAsync
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-auto bias_dram_window = make_tile_window(
-bias_dram_block_window_tmp.get_bottom_tensor_view(),
-bias_dram_block_window_tmp.get_window_lengths(),
-{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+auto bias_dram_window =
+make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+bias_dram_block_window_tmp.get_window_lengths(),
+{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
......@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
-buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
+buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
......@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
-async_load_fence(k_dram_window.get_num_access());
+async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
......
......@@ -9,9 +9,10 @@
namespace ck_tile {
+/// NOTICE: we no longer use this pipeline.
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
-struct BlockFmhaPipelineQSKSVS
+struct [[deprecated]] BlockFmhaPipelineQSKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
......
......@@ -15,6 +15,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
......@@ -64,13 +65,28 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
-return make_static_tile_distribution(
-tile_distribution_encoding<sequence<1>,
-tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-tuple<sequence<1>, sequence<2, 1>>,
-tuple<sequence<1>, sequence<1, 2>>,
-sequence<1, 2, 2>,
-sequence<0, 0, 2>>{});
+if constexpr(1 < Problem::kNumGemm0Warps)
+{
+return make_static_tile_distribution(
+tile_distribution_encoding<sequence<1>,
+tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+tuple<sequence<1>, sequence<2, 1>>,
+tuple<sequence<1>, sequence<1, 2>>,
+sequence<1, 2, 2>,
+sequence<0, 0, 2>>{});
+}
+else
+{
+static_assert(MWarp == 1);
+return make_static_tile_distribution(
+tile_distribution_encoding<sequence<1>,
+tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+tuple<sequence<2, 1>>,
+tuple<sequence<1, 2>>,
+sequence<1, 2, 2>,
+sequence<0, 0, 2>>{});
+}
}
template <typename Problem>
......@@ -80,7 +96,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
-Problem::kBlockSize,
+Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
......@@ -129,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
-return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
+if constexpr(1 < Problem::kNumGemm0Warps)
+return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
+else
+return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
+/// NOTICE: we no longer use this policy.
template <>
-struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
+struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
{
static constexpr bool QLoadOnce = false;
......@@ -364,12 +384,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
+constexpr index_t kMaxVecLoad =
+min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
+constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
-// TODO: not correct!
-if constexpr(total_pixels > 4)
-return 4;
-else
-return 2;
+constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
+? kMaxVecLoad
+: (total_pixels / kMinVecLoad);
+return kVecLoad;
}
else
{
......@@ -383,10 +406,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
-using CWarpDstr = typename WG::CWarpDstr;
-constexpr auto vec =
-CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
-return vec;
+return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
......@@ -395,10 +416,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
-using CWarpDstr = typename WG::CWarpDstr;
-constexpr auto vec =
-CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
-return vec;
+return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
......@@ -449,44 +468,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return max(SingleKSize, SingleVSize);
}
-template <typename Problem, typename BlockGemm>
-CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor()
-{
-constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
-constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-using WG = remove_cvref_t<decltype(config.template at<0>())>;
-constexpr index_t MWarp = config.template at<1>();
-constexpr index_t NWarp = config.template at<2>();
-constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
-constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
-constexpr auto q_block_outer_dstr_encoding =
-tile_distribution_encoding<sequence<NWarp>,
-tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-tuple<sequence<1, 0>>,
-tuple<sequence<1, 0>>,
-sequence<1, 2>,
-sequence<0, 0>>{};
-constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
-constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
-return q_block_dstr;
-}
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
......@@ -886,36 +873,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
}
-template <typename Problem, typename BlockGemm>
+template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
-constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0;
-constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0;
-constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-using WG = remove_cvref_t<decltype(config.template at<0>())>;
-constexpr index_t MWarp = config.template at<1>();
-constexpr index_t NWarp = config.template at<2>();
-constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-// Construct C-Block-HostTensor
-constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
-sequence<>,
-tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
-tuple<sequence<1, 2>>,
-tuple<sequence<1, 1>>,
-sequence<1, 2>,
-sequence<0, 0>>{};
-constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
-constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
-return c_block_dstr;
+return BlockGemm::MakeCBlockTile().get_tile_distribution();
}
template <typename Problem>
......@@ -972,7 +933,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
-Problem::kBlockSize,
+Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
......
......@@ -21,10 +21,15 @@ struct TileFmhaShape
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
-static constexpr index_t NumWarps =
+static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
+static constexpr index_t NumGemm1Warps =
+reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
+static_assert(NumGemm1Warps % NumGemm0Warps == 0);
+static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
-static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
......
......@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......
......@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
});
}
-CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
+CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegOneWarpV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static_assert(kBlockSize == get_warp_size(), "Check failed!");
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = 0;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile
......@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
});
}
-CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
+CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
......
......@@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2
});
}
-CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
+CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
......
......@@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1
});
}
-CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
+CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
......
......@@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1
});
}
-CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
+CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
......
......@@ -119,7 +119,7 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
auto inv_std = tile_elementwise_in(
[&](const auto& v_) {
-return type_convert<ComputeDataType>(1.0f) / (sqrt(v_) + epsilon);
+return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
},
var);
......