Commit 036c5234 authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 22995e9a 7843a8a7
...@@ -30,27 +30,29 @@ args: ...@@ -30,27 +30,29 @@ args:
-mode kernel mode. 0:batch, 1:group (default:0) -mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2) -b batch size (default:2)
-h num of head, for q (default:8) -h num of head, for q (default:8)
-h_k num of head, for k/v, 0 means equal to h (default:0) -h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
-s_k seqlen_k, 0 means equal to s (default:0) -s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128) -d head dim for q, k (default:128)
-d_v head dim for v, 0 means equal to d (default:0) -d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k note when squant=1, this value will be modified by range_q/k
-range_q per-tensor quantization range of q. used if squant=1. (default:2) -range_q per-tensor quantization range of q. used if squant=1. (default:16)
-range_k per-tensor quantization range of k. used if squant=1. (default:2) -range_k per-tensor quantization range of k. used if squant=1. (default:16)
-range_v per-tensor quantization range of v. used if squant=1. (default:2) -range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
scale_o according to range_q, range_k, range_v, range_p, range_o calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1) -iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1) -operm permute output (default:1)
-bias add bias or not (default:0) -bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16) -prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask 't', top-left causal mask, 'b', bottom-r causal mask
...@@ -59,11 +61,14 @@ args: ...@@ -59,11 +61,14 @@ args:
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0) -lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0) -kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) -init init method. ui, uniform random int, ni, normalized random int (default:uf)
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fp16 fmha case with batch=1, nhead=16, sequence length=16384, hdim=128.
...@@ -85,6 +90,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov ...@@ -85,6 +90,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov
### attention bias
Attention bias is supported with a layout of `1*1*s*s` (similar to input/output; a different layout can be supported by changing the stride value for bias, or even extended to `b*h*s*s`), with bias values in float.
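Example (using the flags documented above; the concrete sizes are illustrative): `./bin/tile_example_fmha_fwd -b=2 -h=8 -s=1024 -d=128 -bias=e` enables the default `1*1*s*s` elementwise bias, and `-bias=e:2` switches to a per-batch, per-head `b*h*s*s` bias tensor.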
### alibi
Alibi bias is supported through `-bias=a` (slopes laid out as `1*h`) or `-bias=a:1` (slopes laid out as `b*h`), as described in the option list above.
### lse
For training kernels, the "log sum exp" (lse) values need to be stored in the forward pass and reused in the backward pass. We support this by setting `-lse=1`.
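Example (illustrative sizes): `./bin/tile_example_fmha_fwd -b=2 -h=8 -s=1024 -d=128 -lse=1` stores the lse tensor, laid out as `[batch, nhead, seqlen_q]` in this example.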
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};
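A minimal usage sketch of the decode/serialize round trip above (the `bias.hpp` include name is assumed from the `#include "bias.hpp"` added to `fmha_fwd.hpp` below; output strings follow the `serialize` logic):

```cpp
#include <iostream>
#include "bias.hpp" // assumed location of the bias_info definition above

int main()
{
    bias_info b0 = bias_info::decode("n");   // no bias
    bias_info b1 = bias_info::decode("e:2"); // elementwise bias, b*h*s*s
    bias_info b2 = bias_info::decode("a:1"); // alibi, slopes per batch and head

    std::cout << b0 << " " << b1 << " " << b2 << std::endl; // prints: n e[2] alibi[1]
    return 0;
}
```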
...@@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[]) ...@@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[])
.insert("b", "2", "batch size") .insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q") .insert("h", "8", "num of head, for q")
.insert("h_k", .insert("h_k",
"0", "-1",
"num of head, for k/v, 0 means equal to h\n" "num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case") "if not equal to h, then this is GQA/MQA case")
.insert("s", .insert("s",
"3328", "3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n" "seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
.insert("s_k", "0", "seqlen_k, 0 means equal to s") .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k") .insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d") .insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s", .insert("scale_s",
"0", "0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
...@@ -60,18 +60,24 @@ auto create_args(int argc, char* argv[]) ...@@ -60,18 +60,24 @@ auto create_args(int argc, char* argv[])
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
.insert( .insert("squant",
"squant", "auto",
"0", "if using static quantization fusion or not. auto: fp8 will default use squant, "
"if using static quantization fusion or not. 0: original flow(not prefered)\n" "other will not\n"
"1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
"scale_o according to range_q, range_k, range_v, range_p, range_o") "P and O.\n"
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
"range_p, range_o")
.insert("iperm", .insert("iperm",
"1", "1",
"permute input\n" "permute input\n"
"if true, will be b*h*s*d, else b*s*h*d") "if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output") .insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not") .insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask", .insert("mask",
"0", "0",
...@@ -88,8 +94,11 @@ auto create_args(int argc, char* argv[]) ...@@ -88,8 +94,11 @@ auto create_args(int argc, char* argv[])
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse") .insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name") .insert("kname", "0", "if set to 1 will print kernel name")
.insert( .insert("init",
"init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization") "uf",
"init method. ui, uniform random int, ni, normalized random int\n"
"uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, "
"quantization")
.insert("seed", .insert("seed",
"11939", "11939",
"random seed used for initializing input tensors. 0 for " "random seed used for initializing input tensors. 0 for "
...@@ -103,7 +112,7 @@ auto create_args(int argc, char* argv[]) ...@@ -103,7 +112,7 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
auto get_elimit(int /*init_method*/) auto get_elimit(std::string /*init_method*/)
{ {
double rtol = 1e-3; double rtol = 1e-3;
double atol = 1e-3; double atol = 1e-3;
...@@ -111,9 +120,15 @@ auto get_elimit(int /*init_method*/) ...@@ -111,9 +120,15 @@ auto get_elimit(int /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(int init_method) auto get_elimit<ck_tile::bf16_t>(std::string init_method)
{ {
if(init_method == 0) if(init_method == "ui" || init_method == "ni")
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
else if(init_method == "nf")
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -128,9 +143,9 @@ auto get_elimit<ck_tile::bf16_t>(int init_method) ...@@ -128,9 +143,9 @@ auto get_elimit<ck_tile::bf16_t>(int init_method)
} }
template <> template <>
auto get_elimit<ck_tile::fp8_t>(int init_method) auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{ {
if(init_method == 0) if(init_method == "ui" || init_method == "ni")
{ {
unsigned max_rounding_point_distance = 0; unsigned max_rounding_point_distance = 0;
double atol = 2e-3; double atol = 2e-3;
...@@ -153,7 +168,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -153,7 +168,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k == 0) if(nhead_k < 0)
nhead_k = nhead; nhead_k = nhead;
if(nhead % nhead_k != 0) if(nhead % nhead_k != 0)
...@@ -164,11 +179,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -164,11 +179,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t seqlen_q = arg_parser.get_int("s"); ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
if(seqlen_k == 0) if(seqlen_k < 0)
seqlen_k = seqlen_q; seqlen_k = seqlen_q;
ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v == 0) if(hdim_v < 0)
hdim_v = hdim_q; hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
...@@ -178,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -178,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale_s == .0f) if(scale_s == .0f)
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ? scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
bool squant = arg_parser.get_bool("squant"); std::string squant_str = arg_parser.get_str("squant");
if constexpr(!std::is_same_v<DataType, ck_tile::fp8_t>) bool squant = [&]() {
{ if(squant_str == "auto")
if(squant)
{ {
std::cerr << "static quantization only support fp8 for now" << std::endl; if(data_type == "fp8")
return false; return true;
else
return false;
} }
} else
return atoi(squant_str.c_str()) != 0 ? true : false;
}();
float range_q = arg_parser.get_float("range_q"); float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k"); float range_k = arg_parser.get_float("range_k");
...@@ -208,12 +226,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -208,12 +226,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
std::string vlayout = arg_parser.get_str("vlayout"); std::string vlayout = arg_parser.get_str("vlayout");
bool use_bias = arg_parser.get_bool("bias");
bool lse = arg_parser.get_bool("lse"); bool lse = arg_parser.get_bool("lse");
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
int init_method = arg_parser.get_int("init"); std::string init_method = arg_parser.get_str("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed"); std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0) if(*seed == 0)
{ {
...@@ -295,12 +313,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -295,12 +313,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<VDataType> v_host( ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
ck_tile::HostTensor<BiasDataType> bias_host( ck_tile::HostTensor<BiasDataType> bias_host(
use_bias bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q] // self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q} lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
...@@ -309,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -309,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
if(init_method == 0) if(init_method == "ui" || init_method == "0")
{ {
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host); ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host); ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host); ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host); ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
} }
else if(init_method == 1) else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "uf" || init_method == "1")
{ {
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
} }
else if(init_method == 2) else if(init_method == "nf")
{
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
}
else if(init_method == "tf" || init_method == "2")
{ {
ck_tile::FillTrigValue<QDataType>{}(q_host); ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host); ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host); ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host); ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
} }
else if(init_method == 3) // suitable for fp8 quantization else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization
{ {
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
...@@ -341,6 +380,24 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -341,6 +380,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Assume bias is in [-1.f, 1.f] in original fp32 // Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
} }
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
assert(slopes.size() == nhead);
if(bias.rank_info == 0)
{
// alibi in 1*h
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
}
else
{
// alibi in b*h
for(auto i_b = 0; i_b < batch; i_b++)
{
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
}
}
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
...@@ -350,6 +407,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -350,6 +407,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data()); k_buf.ToDevice(k_host.data());
...@@ -357,6 +415,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -357,6 +415,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias_buf.ToDevice(bias_host.data()); bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data()); seqstart_k.ToDevice(seqstart_k_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off // clang-format off
auto layout_str = [&](bool permute){ auto layout_str = [&](bool permute){
...@@ -372,9 +431,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -372,9 +431,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
<< ", mask:" << mask << ", v:" << vlayout << std::flush; << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q, auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -382,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -382,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mode == mode_enum::group, mode == mode_enum::group,
is_v_rowmajor, is_v_rowmajor,
mask.type, mask.type,
use_bias, bias.type,
lse, lse,
squant}; squant};
...@@ -441,7 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -441,7 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
return fmha_fwd_args{q_buf.GetDeviceBuffer(), return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
...@@ -461,7 +521,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -461,7 +521,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_q, stride_q,
stride_k, stride_k,
stride_v, stride_v,
stride_bias, bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_o, stride_o,
nhead_stride_q, nhead_stride_q,
nhead_stride_k, nhead_stride_k,
...@@ -564,8 +625,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -564,8 +625,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{}, ck_tile::identity{},
ck_tile::scales(scale_s)); ck_tile::scales(scale_s));
if(use_bias) if(bias.type == bias_enum::elementwise_bias)
{ {
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k}); ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off // clang-format off
if(i_perm) if(i_perm)
...@@ -582,6 +644,51 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -582,6 +644,51 @@ bool run(const ck_tile::ArgParser& arg_parser)
SMPLComputeDataType>( SMPLComputeDataType>(
s_host_ref, bias_host_ref, s_host_ref); s_host_ref, bias_host_ref, s_host_ref);
} }
else if(bias.type == bias_enum::alibi)
{
// alibi construct elementwise bias to verify
auto alibi_host = [&]() {
if(mask.type != mask_enum::no_mask)
{
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
0,
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
}
else
{
return ck_tile::Alibi<SaccDataType, true>{
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
}
}();
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
for(auto i_h = 0; i_h < nhead; i_h++)
{
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
alibi_host.slope = current_slope;
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
SaccDataType pixel = 0;
alibi_host.update(pixel, i_r, i_c);
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
}
}
}
// [nhead, real_seqlen_q, real_seqlen_k]
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
SaccDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
s_host_ref, alibi_bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask) if(mask.type == mask_enum::no_mask)
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp" #include "mask.hpp"
#include "bias.hpp"
#include <type_traits> #include <type_traits>
template <typename DataType> template <typename DataType>
...@@ -86,7 +87,7 @@ struct fmha_fwd_args ...@@ -86,7 +87,7 @@ struct fmha_fwd_args
const void* q_ptr; const void* q_ptr;
const void* k_ptr; const void* k_ptr;
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; const void* bias_ptr; // bias or alibi_slope pointer
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
...@@ -106,7 +107,7 @@ struct fmha_fwd_args ...@@ -106,7 +107,7 @@ struct fmha_fwd_args
ck_tile::index_t stride_q; ck_tile::index_t stride_q;
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
...@@ -219,7 +220,7 @@ template <ck_tile::index_t HDim_, ...@@ -219,7 +220,7 @@ template <ck_tile::index_t HDim_,
bool kIsVLayoutRowMajor_, bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_, ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_, typename FmhaMask_,
bool kHasBias_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_, bool kStoreLse_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kPadS_, bool kPadS_,
...@@ -240,7 +241,7 @@ struct fmha_fwd_traits_ ...@@ -240,7 +241,7 @@ struct fmha_fwd_traits_
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
...@@ -261,7 +262,7 @@ struct fmha_fwd_traits ...@@ -261,7 +262,7 @@ struct fmha_fwd_traits
bool is_group_mode; bool is_group_mode;
bool is_v_rowmajor; bool is_v_rowmajor;
mask_enum mask_type; mask_enum mask_type;
bool has_bias; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse; bool has_lse;
bool do_fp8_static_quant; bool do_fp8_static_quant;
// TODO: padding check is inside this api // TODO: padding check is inside this api
......
...@@ -40,6 +40,19 @@ MASK_MAP = { ...@@ -40,6 +40,19 @@ MASK_MAP = {
"generic" : "FmhaMasks::GenericMask" "generic" : "FmhaMasks::GenericMask"
} }
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = { MODE_MAP = {
"batch" : "false", "batch" : "false",
"group" : "true" "group" : "true"
...@@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = { ...@@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = {
"s_mask" : "t.mask_type != mask_enum::no_mask", "s_mask" : "t.mask_type != mask_enum::no_mask",
} }
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a); return fmha_fwd_<trait_>(s, a);
...@@ -213,7 +226,7 @@ class FmhaFwdApiTrait: ...@@ -213,7 +226,7 @@ class FmhaFwdApiTrait:
bk0blen : int bk0blen : int
vlayout : str vlayout : str
mask : str mask : str
bias : str # true/false bias : str #
lse : str # lse : str #
squant : str # squant : str #
spad : str spad : str
...@@ -241,8 +254,8 @@ class FmhaFwdApiTrait: ...@@ -241,8 +254,8 @@ class FmhaFwdApiTrait:
def skcheck(self) -> str: def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async': if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']: elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0' else : return f'a.seqlen_k % {self.bn0} == 0'
...@@ -297,7 +310,7 @@ class FmhaFwdPipeline: ...@@ -297,7 +310,7 @@ class FmhaFwdPipeline:
pn = pad_name() pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}' n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}' if pn != '' : n += f'_{pn}'
if self.F_bias == 't' : n += '_bias' if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_mask[0:2] == 's_': if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask' if self.F_mask == 's_mask': n += f'_mask'
else: else:
...@@ -332,7 +345,8 @@ class FmhaFwdApiPool: ...@@ -332,7 +345,8 @@ class FmhaFwdApiPool:
if_k = 'if' if k == 0 else 'else if' if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
...@@ -400,7 +414,7 @@ class FmhaFwdKernel: ...@@ -400,7 +414,7 @@ class FmhaFwdKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy, F_occupancy = self.F_tile.F_occupancy,
...@@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ ...@@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
} }
elif dtype == 'fp8' or dtype == 'bf8': elif dtype == 'fp8' or dtype == 'bf8':
return { return {
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
} }
else: else:
return None return None
...@@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw ...@@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
squant = 't' if dtype == 'fp8' else 'f' squant = 't' if dtype == 'fp8' else 'f'
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
if hdim == 256: if hdim == 256:
# if True: # if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
...@@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw ...@@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse kernels # no need lse kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
else: else:
assert False assert False
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -149,11 +149,9 @@ struct mask_info ...@@ -149,11 +149,9 @@ struct mask_info
return tmp; return tmp;
} }
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
}; };
inline std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
...@@ -17,7 +17,7 @@ for perm in 0 1 ; do ...@@ -17,7 +17,7 @@ for perm in 0 1 ; do
for vlayout in "r" "c" ; do for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do for lse in 0 1 ; do
for bias in 0 1 ; do for bias in "n" "e" "a"; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
...@@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b ...@@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done done
done done
...@@ -37,9 +38,11 @@ done ...@@ -37,9 +38,11 @@ done
done done
for perm in 0 1 ; do for perm in 0 1 ; do
for bias in 0 1 ; do for bias in "n" "e" "a" ; do
for b in 1 2 ; do for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done done
done done
done done
done
...@@ -4,12 +4,19 @@ ...@@ -4,12 +4,19 @@
#pragma once #pragma once
#include "ck/config.h" #include "ck/config.h"
#include "ck/utility/env.hpp"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
#endif #endif
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// to do: add various levels of logging with CK_LOG_LEVEL
#define CK_TIME_KERNEL 1 #define CK_TIME_KERNEL 1
// constant address space for kernel parameter // constant address space for kernel parameter
...@@ -225,17 +232,17 @@ ...@@ -225,17 +232,17 @@
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
// denorm test fix, required to work around dissue // denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX #ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0 #define CK_WORKAROUND_DENORM_FIX 0
#else #else
// enable only on MI200 // enable only for gfx90a
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX #endif // CK_WORKAROUND_DENORM_FIX
// set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1
namespace ck { namespace ck {
enum struct InMemoryDataOperationEnum enum struct InMemoryDataOperationEnum
......
...@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported() ...@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
} }
inline bool is_navi1_supported() inline bool is_gfx101_supported()
{ {
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
ck::get_device_name() == "gfx1012"; ck::get_device_name() == "gfx1012";
} }
inline bool is_navi2_supported() inline bool is_gfx103_supported()
{ {
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
} }
inline bool is_navi3_supported() inline bool is_gfx11_supported()
{ {
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <set>
#include <vector>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/flush_icache.hpp"
namespace ck {
namespace utility {
template <typename Argument>
struct RotatingMemWrapper
{
using ADataType = decltype(Argument::p_a_grid);
using BDataType = decltype(Argument::p_b_grid);
RotatingMemWrapper() = delete;
RotatingMemWrapper(Argument& arg_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_)
: arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
{
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
}
}
void Print()
{
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapper()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
};
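A minimal usage sketch of the wrapper above (the `Arg` struct, buffer sizes, rotation count, and launch loop are illustrative assumptions, not taken from the source):

```cpp
#include <cstddef>
// assumes the RotatingMemWrapper header above is included

// Any argument struct exposing p_a_grid / p_b_grid pointer members works.
struct Arg
{
    const void* p_a_grid;
    const void* p_b_grid;
};

void benchmark_with_rotating_buffers(Arg& arg, std::size_t size_a, std::size_t size_b)
{
    // Keep 4 copies of the A/B buffers so each timed launch reads different memory.
    ck::utility::RotatingMemWrapper<Arg> rotating(arg, /*rotating_count=*/4, size_a, size_b);
    rotating.Print();

    for(int i = 0; i < 20; ++i)
    {
        rotating.Next(); // repoint arg.p_a_grid / arg.p_b_grid at the next copy
        // ... launch the kernel with `arg` here ...
    }
    // the destructor restores the original pointers and frees the extra copies
}
```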
inline void flush_icache()
{
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
hip_check_error(hipGetLastError());
}
// if TimePrePress == false, return time does not include preprocess's time
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args& args)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
if(stream_config.time_kernel_)
{
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
}
// warm up
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
}
const int nrepeat = stream_config.nrepeat_;
if(nrepeat == 0)
{
return 0.0;
}
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
printf("Start running %d times...\n", nrepeat);
}
#if MEDIAN
std::set<float> times;
#else
float total_time = 0;
#endif
for(int i = 0; i < nrepeat; ++i)
{
if constexpr(!TimePreprocess)
{
preprocess();
}
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
if constexpr(TimePreprocess)
{
preprocess();
}
// run real kernel
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
// end real kernel
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
static_cast<const void*>(args.p_a_grid),
static_cast<const void*>(args.p_b_grid));
}
}
#if MEDIAN
auto mid = times.begin();
std::advance(mid, (nrepeat - 1) / 2);
if(nrepeat % 2 == 1)
{
return *mid;
}
else
{
auto mid_next = mid;
std::advance(mid_next, 1);
return (*mid + *mid_next) / 2;
}
#else
return total_time / nrepeat;
#endif
}
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
hip_check_error(hipGetLastError());
return 0;
#endif
}
} // namespace utility
} // namespace ck
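A self-contained sketch of how this timing helper might be called (the kernel, argument struct, and launch dimensions are illustrative assumptions):

```cpp
// assumes the header above is included and the file is compiled with hipcc
struct DummyArgs
{
    int n;
};

__global__ void dummy_kernel(DummyArgs args) { (void)args; }

float time_dummy_kernel()
{
    StreamConfig cfg{};
    cfg.time_kernel_ = true;
    cfg.cold_niters_ = 5;  // warm-up launches
    cfg.nrepeat_     = 20; // timed launches; the median is returned when MEDIAN == 1

    DummyArgs args{1024};
    return ck::utility::launch_and_time_kernel_with_preprocess<false>(
        cfg,
        [] { ck::utility::flush_icache(); }, // preprocess, excluded from the measured time
        dummy_kernel,
        dim3(1),
        dim3(64),
        /*lds_byte=*/0,
        args);
}
```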
...@@ -20,18 +20,19 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -20,18 +20,19 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
if(stream_config.time_kernel_) if(stream_config.time_kernel_)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", {
__func__, printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
grid_dim.x, __func__,
grid_dim.y, grid_dim.x,
grid_dim.z, grid_dim.y,
block_dim.x, grid_dim.z,
block_dim.y, block_dim.x,
block_dim.z); block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif printf("Warm up %d times\n", stream_config.cold_niters_);
}
// warm up // warm up
for(int i = 0; i < stream_config.cold_niters_; ++i) for(int i = 0; i < stream_config.cold_niters_; ++i)
{ {
...@@ -40,9 +41,10 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -40,9 +41,10 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
} }
const int nrepeat = stream_config.nrepeat_; const int nrepeat = stream_config.nrepeat_;
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
printf("Start running %d times...\n", nrepeat); {
#endif printf("Start running %d times...\n", nrepeat);
}
hipEvent_t start, stop; hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start)); hip_check_error(hipEventCreate(&start));
...@@ -93,18 +95,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -93,18 +95,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
if(stream_config.time_kernel_) if(stream_config.time_kernel_)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", {
__func__, printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
grid_dim.x, __func__,
grid_dim.y, grid_dim.x,
grid_dim.z, grid_dim.y,
block_dim.x, grid_dim.z,
block_dim.y, block_dim.x,
block_dim.z); block_dim.y,
block_dim.z);
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif printf("Warm up %d times\n", stream_config.cold_niters_);
}
// warm up // warm up
for(int i = 0; i < stream_config.cold_niters_; ++i) for(int i = 0; i < stream_config.cold_niters_; ++i)
{ {
...@@ -114,9 +117,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -114,9 +117,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
} }
const int nrepeat = stream_config.nrepeat_; const int nrepeat = stream_config.nrepeat_;
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
printf("Start running %d times...\n", nrepeat); {
#endif printf("Start running %d times...\n", nrepeat);
}
hipEvent_t start, stop; hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start)); hip_check_error(hipEventCreate(&start));
......
...@@ -13,4 +13,7 @@ struct StreamConfig ...@@ -13,4 +13,7 @@ struct StreamConfig
int log_level_ = 0; int log_level_ = 0;
int cold_niters_ = 5; int cold_niters_ = 5;
int nrepeat_ = 50; int nrepeat_ = 50;
bool flush_cache = false;
int rotating_count = 1;
}; };
...@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave, ...@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / (4 * warpSize / BlockSize), 32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages = static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2 FullMemBandPrefetchStages >= 2
...@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave, ...@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / (4 * warpSize / BlockSize), 32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages = static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2 FullMemBandPrefetchStages >= 2
......
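Both pipeline_v2 schedulers now factor the `4 * warpSize / BlockSize` term into `WgpPerCU` and clamp it to at least 1, which keeps the `FullMemBandPrefetchStages` divisor nonzero when `BlockSize` exceeds `4 * warpSize`. A compile-time check with illustrative numbers:

```cpp
// Worked example of the clamp (warpSize = 64 is assumed for the example).
constexpr int warpSize  = 64;
constexpr int BlockSize = 512; // 4 * 64 / 512 == 0 without the clamp
constexpr int WgpPerCU =
    (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
static_assert(WgpPerCU == 1, "clamped so 32768 / WgpPerCU stays well-defined");
```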
...@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto ds_read_b_issue_cycle = constexpr auto ds_read_b_issue_cycle =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate = constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle; (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate = constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle; (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1 // stage 1
// Separate this part? // Separate this part?
constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
sizeof(ComputeDataType) / sizeof(BDataType) // sizeof(ComputeDataType) / sizeof(BDataType)
? sizeof(ComputeDataType) / sizeof(ADataType) // ? sizeof(ComputeDataType) / sizeof(ADataType)
: sizeof(ComputeDataType) / sizeof(BDataType); // : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr auto num_mfma_stage1 = constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
num_mfma_inst - num_mfma_per_ds_read * (num_ds_read_inst_a / ds_read_a_mfma_rate +
num_ds_read_inst_b / ds_read_b_mfma_rate);
constexpr auto num_mfma_per_issue = constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
...@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}); });
// stage 2 // stage 2
static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) { static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
ignore = i; if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read ds_read_a_mfma_rate)
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA {
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) { static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
ignore = i; if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read ds_read_b_mfma_rate)
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA {
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); });
} }
......
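Stage 2 of the pipeline_v3 scheduler now issues the `ds_read` instructions in ceil-divided groups of `ds_read_*_mfma_rate`, with the last group carrying only the remainder, so the groups sum to exactly `num_ds_read_inst_*`. A compile-time check of that bookkeeping with hypothetical instruction counts:

```cpp
// Hypothetical counts; verifies the ceil-division plus remainder logic behind
// the new static_for over num_dsread_a_mfma.
constexpr int num_ds_read_inst_a  = 10;
constexpr int ds_read_a_mfma_rate = 4;
constexpr int num_dsread_a_mfma =
    (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; // 3 groups
constexpr int tail_group =
    num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate;   // 2 reads left
static_assert(num_dsread_a_mfma == 3 && tail_group == 2,
              "two full groups of 4 plus a final group of 2 covers all 10 ds_reads");
```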
...@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto ds_read_b_issue_cycle = constexpr auto ds_read_b_issue_cycle =
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate = constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate = constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
......
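pipeline_v5 (and pipeline_v3 above) switch the pairing rate to `(mfma_cycle - 4 + 2 * issue_cycle - 1) / (2 * issue_cycle)`. With assumed cycle counts, the new constant evaluates as follows:

```cpp
// Illustrative evaluation of the revised rate formula; mfma_cycle = 32 and a
// 128-bit LDS read (issue cycle = 8) are assumptions made for the example.
constexpr int mfma_cycle            = 32;
constexpr int ds_read_a_issue_cycle = 8;
constexpr int ds_read_a_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
static_assert(ds_read_a_mfma_rate == 2,
              "with these cycle counts, two ds_reads are grouped per MFMA slot");
```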
...@@ -41,7 +41,8 @@ template <typename ThreadGroup, ...@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t DstScalarPerVector, index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags, typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags> typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r2 struct ThreadGroupTensorSliceTransfer_v7r2
{ {
static constexpr index_t nDim = static constexpr index_t nDim =
...@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2 ...@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2 ...@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
} }
} }
template <typename SrcBuffers> template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) __device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_descs, src_bufs); threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
} }
} }
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers> template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) __device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value) if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs); threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs)); threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
} }
} }
...@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2 ...@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
SrcScalarPerVector, SrcScalarPerVector,
DstScalarPerVector, DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>; ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
......
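`ThreadGroupTensorSliceTransfer_v7r2` now carries a `NumThreadScratch` count and forwards a `Number<ThreadScratchId>` slot index to the threadwise transfer, so a caller can keep more than one in-flight copy in thread scratch. A hedged caller-side sketch; the helper name and surrounding pipeline are hypothetical, only `RunRead`/`RunWrite` and `ck::Number` come from the diff:

```cpp
// Hypothetical ping-pong over two scratch slots: slot 1 can be prefetched
// while slot 0 is drained. Source-window advancement between reads is omitted.
template <typename Transfer,
          typename SrcDescs, typename SrcBufs,
          typename DstDescs, typename DstBufs>
__device__ void double_buffered_copy(Transfer& transfer,
                                     const SrcDescs& src_descs,
                                     const SrcBufs& src_bufs,
                                     const DstDescs& dst_descs,
                                     DstBufs& dst_bufs)
{
    transfer.RunRead(src_descs, src_bufs, ck::Number<0>{});  // fill slot 0
    transfer.RunRead(src_descs, src_bufs, ck::Number<1>{});  // fill slot 1
    transfer.RunWrite(dst_descs, dst_bufs, ck::Number<0>{}); // drain slot 0
    transfer.RunWrite(dst_descs, dst_bufs, ck::Number<1>{}); // drain slot 1
}
```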
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing a single GEMM problem's arguments.
///
/// A pointer to the vector of these structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
__host__ __device__
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument to be updated.
/// @param[in] p_dev_kernel_args The pointer to the device memory that contains the
/// kernel arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
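The new header defines per-group kernel arguments plus two hooks, `GetDeviceKernelArgSize` and `SetDeviceKernelArgs`, for staging those arguments in device memory. A host-side sketch of the intended flow; the pointers, problem sizes, and `device_op`/`p_arg` are placeholders, the op and Argument construction are elided, and `<hip/hip_runtime.h>` plus this header are assumed to be included:

```cpp
// Build one GroupedGemmTileLoopKernelArguments entry per group on the host,
// copy the packed array to device memory, then hand the device pointer to the op.
#include <array>
#include <cstddef>
#include <vector>
#include <hip/hip_runtime.h>

template <typename DeviceOp, typename BaseArg>
void stage_grouped_gemm_args(const DeviceOp& device_op, BaseArg* p_arg,
                             const void* p_a, const void* p_b,
                             const void* p_d0, void* p_e)
{
    constexpr ck::index_t NumDTensor = 1;
    using KernelArg =
        ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;

    std::vector<KernelArg> gemm_args;
    gemm_args.emplace_back(p_a, p_b,
                           std::array<const void*, NumDTensor>{p_d0}, p_e,
                           /*M=*/1024, /*N=*/768, /*K=*/512,
                           /*StrideA=*/512, /*StrideB=*/768,
                           std::array<ck::index_t, NumDTensor>{768},
                           /*StrideE=*/768);
    gemm_args.back().Print(); // prints "arg {M:1024, N:768, K:512, ...}"

    // Stage the packed argument array on the device via the two new hooks.
    // GetDeviceKernelArgSize(p_arg) is expected to report this same byte count.
    const std::size_t arg_bytes = gemm_args.size() * sizeof(KernelArg);
    void* p_dev_args            = nullptr;
    hip_check_error(hipMalloc(&p_dev_args, arg_bytes));
    hip_check_error(hipMemcpy(p_dev_args, gemm_args.data(), arg_bytes,
                              hipMemcpyHostToDevice));
    device_op.SetDeviceKernelArgs(p_arg, p_dev_args);
}
```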
...@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported())
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s, BatchStrideD1s,
BatchStrideE1} BatchStrideE1}
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " {
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", " << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
<< std::endl; << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " << std::endl;
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
#endif << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
static_for<0, NumD0Tensor, 1>{}([&](auto i) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>; using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
......