Commit f1793462 authored by rocking

Add rowwise dynamic quant kernel

parent b5d201dd
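For context, the new path computes quantization scales per row at runtime: each row of Q, K and V gets its own scale derived from that row's absmax, instead of the single precomputed per-tensor scales used by the existing fp8 static-quant (`squant`) path. Below is a minimal out-of-kernel sketch of the idea, assuming fp8 e4m3 with a max finite value of 448 and a placeholder conversion helper; it is illustrative only, not the fused ck_tile implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical stand-in for a real float -> fp8(e4m3) conversion; here we only
// clamp to the representable range so the sketch stays self-contained.
inline float float_to_fp8_e4m3(float x)
{
    constexpr float fp8_max = 448.f; // largest finite e4m3 value
    return std::clamp(x, -fp8_max, fp8_max);
}

// Rowwise dynamic quantization of one [rows x cols] tensor (e.g. Q, K or V):
// each row gets its own scale computed from that row's absmax at runtime,
// instead of one precomputed per-tensor scale as in static quantization.
void quantize_rowwise(const std::vector<float>& x, int rows, int cols,
                      std::vector<float>& x_q, std::vector<float>& row_scale)
{
    constexpr float fp8_max = 448.f;
    x_q.resize(x.size());
    row_scale.resize(rows);
    for(int r = 0; r < rows; ++r)
    {
        float amax = 0.f;
        for(int c = 0; c < cols; ++c)
            amax = std::max(amax, std::fabs(x[r * cols + c]));
        const float scale = amax > 0.f ? amax / fp8_max : 1.f; // kept per row for dequant
        row_scale[r] = scale;
        for(int c = 0; c < cols; ++c)
            x_q[r * cols + c] = float_to_fp8_e4m3(x[r * cols + c] / scale);
    }
}
```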
......@@ -6,7 +6,6 @@ FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
......
......@@ -56,6 +56,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_squant},
{F_rdquant},
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
......@@ -88,7 +89,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
......@@ -123,9 +124,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.do_fp8_rowwise_dynamic_quant == {F_rdquant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_rdquant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
......@@ -149,6 +150,7 @@ class FmhaFwdApiTrait:
lse : str #
dropout : str
squant : str #
rdquant : str #
spad : str
skpad : str
dpad : str
......@@ -157,7 +159,7 @@ class FmhaFwdApiTrait:
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.rdquant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property
def scheck(self) -> str:
......@@ -218,6 +220,7 @@ class FmhaFwdPipeline:
F_lse : str #
F_dropout : str #
F_squant : str #
F_rdquant : str #
F_mask : str # value from MASK_MAP
@property
......@@ -241,6 +244,7 @@ class FmhaFwdPipeline:
if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout'
if self.F_squant == 't' : n += '_squant'
if self.F_rdquant == 't' : n += '_rdquant'
return n
class FmhaFwdApiPool:
......@@ -271,7 +275,7 @@ class FmhaFwdApiPool:
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_squant=BOOL_MAP[trait.squant], F_rdquant=BOOL_MAP[trait.rdquant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
......@@ -357,6 +361,7 @@ class FmhaFwdKernel:
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_rdquant = BOOL_MAP[self.F_pipeline.F_rdquant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
......@@ -391,6 +396,7 @@ class FmhaFwdKernel:
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
rdquant=self.F_pipeline.F_rdquant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
......@@ -407,7 +413,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
elif dtype == 'fp8' or dtype == 'bf8' or dtype == 'fp8bf16':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
......@@ -424,39 +430,39 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# TODO: the order of this list matters! later entries will also be checked later
# TODO: currently for the qr pipeline, let 't' padding appear later!!
# TODO: how to design this more generically?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitrary hdim
elif dtype in ['fp8', 'bf8']:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitrary hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, 'f', 'f', mask)) # TODO: cover arbitrary hdim
elif dtype in ['fp8']:
# no need for lse/dropout kernels
# squant
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', 't', 'f', mask))
elif dtype in ['fp8bf16']:
# rdquant
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', 'no', 'f', 'f', 'f', 't', 's_no'))
else:
assert False
return pipelines
......@@ -492,6 +498,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_rdquant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
......
......@@ -85,6 +85,10 @@ auto create_args(int argc, char* argv[])
"P and O.\n"
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
"range_p, range_o")
.insert("rdquant",
"0",
"if using rowwise dynamic quantization fusion or not.\n"
"0: no dynamic quant. 1: apply rowwise dynamic quantization with respect Q, K and V.\n")
.insert("iperm",
"1",
"permute input\n"
......@@ -95,7 +99,7 @@ auto create_args(int argc, char* argv[])
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp8bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
......@@ -176,6 +180,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}
template <>
auto get_elimit<FmhaFwdFp8Bf16>(std::string init_method)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
......@@ -1580,6 +1592,10 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) ? 0 : -2;
}
return -3;
}
......@@ -107,6 +107,22 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
using ODataType = ck_tile::bf8_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
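Note that, unlike the existing FmhaFwdFp8 config, this fp8bf16 variant keeps Q/K/V and P in fp8 but writes O in bf16; that is presumably why the rowwise dynamic quant pipelines above are generated only for the fp8bf16 data type, since a bf16 output needs no static output scale.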
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
......@@ -635,6 +651,7 @@ template <ck_tile::index_t HDim_,
bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
bool kDoFp8RowwiseDynamicQuant_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
......@@ -657,6 +674,7 @@ struct fmha_fwd_traits_
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
......@@ -789,6 +807,7 @@ struct fmha_fwd_traits
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
bool do_fp8_rowwise_dynamic_quant;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
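A hedged host-side sketch of how the new runtime flag would be selected through this API, assuming `args` is an already-populated `fmha_fwd_args` and the remaining traits members (hdim, data type, mode, mask, bias, ...) are set as usual:

```cpp
// Illustrative only: pick the rowwise dynamic quant kernels at runtime.
fmha_fwd_traits t{};                    // other members filled in elsewhere
t.has_lse                      = false;
t.has_dropout                  = false;
t.do_fp8_static_quant          = false; // per-tensor static scales not used here
t.do_fp8_rowwise_dynamic_quant = true;  // dispatch to the new *_rdquant instances
float ave_time = fmha_fwd(t, args, ck_tile::stream_config{});
```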
......@@ -804,6 +823,7 @@ struct fmha_fwd_splitkv_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
bool do_fp8_rowwise_dynamic_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
......
......@@ -51,6 +51,8 @@ struct FmhaFwdKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kDoFp8RowwiseDynamicQuant =
FmhaPipeline::Problem::kDoFp8RowwiseDynamicQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -93,7 +95,8 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" ) + (kDoFp8RowwiseDynamicQuant ? "_rdquant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
......
......@@ -46,15 +46,16 @@ struct BlockFmhaPipelineProblem
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kDoFp8RowwiseDynamicQuant = Traits::kDoFp8RowwiseDynamicQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename QDataType_,
......
......@@ -18,19 +18,21 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
bool kDoFp8RowwiseDynamicQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kDoFp8RowwiseDynamicQuant = kDoFp8RowwiseDynamicQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
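A hypothetical instantiation showing where the new flag sits in the parameter list; the values are arbitrary, and the generator above fills these slots via {F_squant} and {F_rdquant}:

```cpp
// Illustrative only: kDoFp8RowwiseDynamicQuant_ comes right after kDoFp8StaticQuant_.
using fmha_trait_rdquant = ck_tile::TileFmhaTraits<
    /*kPadSeqLenQ_*/               false,
    /*kPadSeqLenK_*/               false,
    /*kPadHeadDimQ_*/              false,
    /*kPadHeadDimV_*/              false,
    /*BiasEnum_*/                  ck_tile::BlockAttentionBiasEnum::NO_BIAS,
    /*kHasBiasGrad_*/              false,
    /*kStoreLSE_*/                 false,
    /*kHasDropout_*/               false,
    /*kDoFp8StaticQuant_*/         false,
    /*kDoFp8RowwiseDynamicQuant_*/ true,
    /*kBlockPerCu_*/               -1>;
```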
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
......