Unverified Commit 909f519c authored by Harisankar Sadasivan's avatar Harisankar Sadasivan Committed by GitHub
Browse files

Merge branch 'develop' into universal_streamk

parents 406fa265 3bb0fe6c
...@@ -93,6 +93,8 @@ struct fmha_fwd_args ...@@ -93,6 +93,8 @@ struct fmha_fwd_args
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr; void* rand_val_ptr;
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
...@@ -106,6 +108,7 @@ struct fmha_fwd_args ...@@ -106,6 +108,7 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v; ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q; ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k; ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s; float scale_s;
float scale_p; float scale_p;
float scale_o; float scale_o;
...@@ -114,6 +117,7 @@ struct fmha_fwd_args ...@@ -114,6 +117,7 @@ struct fmha_fwd_args
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval; ck_tile::index_t stride_randval;
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
...@@ -121,6 +125,8 @@ struct fmha_fwd_args ...@@ -121,6 +125,8 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -128,7 +134,11 @@ struct fmha_fwd_args ...@@ -128,7 +134,11 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
...@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
dim3 grids =
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel argumentszs
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
...@@ -282,6 +462,40 @@ struct fmha_fwd_traits_ ...@@ -282,6 +462,40 @@ struct fmha_fwd_traits_
template <typename Traits_> template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN1_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_fwd_traits struct fmha_fwd_traits
{ {
...@@ -298,3 +512,4 @@ struct fmha_fwd_traits ...@@ -298,3 +512,4 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
...@@ -3,1214 +3,62 @@ ...@@ -3,1214 +3,62 @@
# generate kernel instances to speed up compilation # generate kernel instances to speed up compilation
import argparse import argparse
import itertools from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional
from dataclasses import dataclass
import copy
import fnmatch
DTYPE_MAP = { from codegen.cmake_config import *
"fp16": "ck_tile::fp16_t", from codegen.ops import (
"bf16": "ck_tile::bf16_t", fmha_fwd,
"fp8" : "ck_tile::fp8_t" fmha_fwd_splitkv,
} fmha_bwd
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
}
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
TILE_PARTITIONER_MAP = {
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
}
GEN_DIR = "" # in Cmake, have to generate files in same folder
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx},
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
def get_mask_map(mask : str):
if mask == "generic":
return MASK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_check_map(mask : str):
if mask == "generic":
return MASK_CHECK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0blen : int
vlayout : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout'
if self.F_squant == 't' : n += '_squant'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm : int # number of warps along q seqlen (block warps)
F_rn : int # number of warps along k seqlen(not used)
F_rk : int # number of warps along gemm-k(not used)
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\
f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
def get_tp(self) -> str:
if self.F_mode == 'group':
return 'hbs'
else:
return 'shb'
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0blen = self.F_tile.F_bk0blen,
F_rm = self.F_tile.F_rm,
F_rn = self.F_tile.F_rn,
F_rk = self.F_tile.F_rk,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0blen=self.F_tile.F_bk0blen,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
if direction == 'fwd':
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
}
else:
return None
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): class HandlerId(IntEnum):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) LIST_BLOBS = 0
if d == None: WRITE_BLOBS = 1
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
k = FmhaFwdKernel(direction=direction,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen) handlers = {
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
BWD_DQDKDV_PIPELINE_MAP = { 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
"ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR",
}
BWD_DQDKDV_PIPELINE_ENUM_MAP = {
"ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR",
"qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS",
"ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR",
} }
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "fmha_bwd.hpp"
"""
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps0_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps1_{F_idx},
fmha_warp_tile_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_bias},
{F_dbias},
false,
{F_dropout},
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
fmha_bwd_shape_{F_idx},
{F_mode},
fmha_mask_{F_idx},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = {F_pipeline}<
fmha_bwd_pipeline_problem_{F_idx}>;
using fmha_bwd_dk_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false, false>>;
using fmha_bwd_dv_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false, false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
template<>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetName();
}}
"""
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_API="""
#include <iostream>
template<typename dot_do_o_trait_, typename dq_dk_dv_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
);
}}
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_>(s, a);
return r;
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize
else: # self.skpad == 'f' and skpad1 == 'f'
return f'a.seqlen_q % 256 == 0' # BlockSize
@property
def skcheck(self) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.skpad == 't':
return f'a.seqlen_k % {self.bn0} != 0'
else:
return f'a.seqlen_k % {self.bn0} == 0'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
else : return f'a.hdim_q % {self.bhdq} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict()
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.dq_dk_dv_pool.keys():
self.dq_dk_dv_pool[trait.dtype] = dict()
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along gemm0 unroll(F_bhdq)
F_bk1 : int # tile size along gemm1 unroll(F_bm0)
F_bk2 : int # tile size along gemm2 unroll(F_bhdv)
F_bk3 : int # tile size along gemm3 unroll(F_bm0)
F_bk4 : int # tile size along gemm4 unroll(F_bn0)
F_bhdq : int # q head_dim
F_bhdv : int # v head_dim
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3
F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3
F_rm2 : int # number of warps along k seqlen (block warps) in gemm4
F_rn2 : int # number of warps along q seqlen (block warps) in gemm4
F_rk2 : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_occupancy : int # occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
@dataclass
class FmhaBwdDQDKDVKernel:
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_bias : str #
F_dbias : str #
F_dropout : str #
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_pipeline : str
mask_impl : str
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bk1 = self.F_tile.F_bk1,
F_bk2 = self.F_tile.F_bk2,
F_bk3 = self.F_tile.F_bk3,
F_bk4 = self.F_tile.F_bk4,
F_bhdq = self.F_tile.F_bhdq,
F_bhdv = self.F_tile.F_bhdv,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_rm2 = self.F_tile.F_rm2,
F_rn2 = self.F_tile.F_rn2,
F_rk2 = self.F_tile.F_rk2,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_skpad = BOOL_MAP[self.F_skpad],
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = BOOL_MAP[self.F_dropout],
F_occupancy = self.F_tile.F_occupancy,
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline],
F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline])
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_dbias == 't' : n += '_dbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout == 't' : n += '_dropout'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bhdq=self.F_tile.F_bhdq,
bhdv=self.F_tile.F_bhdv,
mask=self.F_mask,
bias=self.F_bias,
dbias=self.F_dbias,
dropout=self.F_dropout,
spad=self.F_spad,
skpad=self.F_skpad,
dpad=self.F_dpad,
dvpad=self.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]:
if direction == 'bwd':
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"qs_ks_vr_dos"],
'128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1),
"ks_vr"]
}
else:
return None
else:
return None
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for pad
# support this in future
gen = list()
api_pool = FmhaBwdApiPool(mask_impl)
for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
if d == None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0]
ppl = d[hdim_str][1]
hdim = int(hdim_str)
if (mode == "group") and (spad == "f" or skpad == "f"):
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad},
{F_dvpad},
{F_occupancy}>;
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
/* BlockSize = */ 256,
{F_hdim},
{F_mode},
fmha_bwd_dot_do_o_trait_{F_idx}>;
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template<>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
template<>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
return k_::GetName();
}}
"""
@dataclass
class FmhaBwdOGradDotOKernel:
direction : str
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_spad : str # true/false
F_dvpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int
@property
def template(self) -> str:
return FMHA_BWD_KERNEL_HEADER + \
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy)
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
return n
@property
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
gen = list()
for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()):
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
if d == None:
continue
for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype,
F_spad=spad, F_dvpad=dvpad, F_mode=mode,
F_occupancy=get_occupancy(dtype, hdim))
gen.append(k)
return gen
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None: if output_dir is None:
output_dir = Path(__file__).parent output_dir = Path(__file__).parent
else: else:
output_dir = Path(output_dir) / GEN_DIR output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
if direction == 'fwd':
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) for api in api_list:
for kernel in kernels: handler = handlers[api][HandlerId.WRITE_BLOBS]
write_single_fwd_kernel(kernel, output_dir) handler(output_dir, kernel_filter, receipt, mask_impl)
write_fwd_api(api_pool, output_dir)
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
write_bwd_api(api_pool, output_dir)
# list all the files that will be generated # list all the files that will be generated
def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None assert output_file is not None
file_path = Path(output_file) file_path = Path(output_file)
with file_path.open('a') as f:
if direction == 'fwd': for api in api_list:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) handler = handlers[api][HandlerId.LIST_BLOBS]
for kernel in kernels: handler(file_path, kernel_filter, receipt, mask_impl)
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
else:
kernels = get_bwd_dot_do_o_blobs()
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="generate", prog="generate",
description="gen api for CK fmha kernel", description="gen API for CK fmha kernel",
) )
parser.add_argument( parser.add_argument(
"-d", "-d",
"--direction", "--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default='fwd', default='fwd',
choices=['fwd', 'bwd'],
required=False, required=False,
help="choose the direction of kernels(default: fwd)" help="supply API(s) to generate (default: fwd). separated by comma."
) )
parser.add_argument( parser.add_argument(
"-o", "-o",
...@@ -1251,7 +99,8 @@ if __name__ == "__main__": ...@@ -1251,7 +99,8 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
api_list = args.direction.split(',')
if args.list_blobs is not None: if args.list_blobs is not None:
list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
else: else:
write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
\ No newline at end of file
...@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__ #define __gfx11__
#endif #endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource // buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
...@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) #elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) #elif defined(__gfx11__) || defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
...@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8 #define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__) #elif defined(__gfx11__) || defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_MFMA_GFX940 #define CK_USE_AMD_MFMA_GFX940
#endif #endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load // buffer load
#define CK_USE_AMD_BUFFER_LOAD 1 #define CK_USE_AMD_BUFFER_LOAD 1
......
...@@ -84,4 +84,9 @@ inline bool is_gfx11_supported() ...@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
} }
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck } // namespace ck
...@@ -13,6 +13,504 @@ ...@@ -13,6 +13,504 @@
namespace ck { namespace ck {
#ifdef __gfx12__
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static constexpr index_t WaveSize = 32;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread level, register decriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / KPack, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
false>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
...@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA ...@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_; typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_; typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
}; };
#endif
} // namespace ck } // namespace ck
...@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point. // sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{ {
#ifdef __gfx12__
asm volatile("\
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("s_barrier" ::); asm volatile("s_barrier" ::);
#endif
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
......
...@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
...@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
...@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
else else
{ {
if(!(arg.a_kz_stride_ == 1 && if(!(arg.a_kz_stride_ == 1))
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{ {
printf("DeviceOp: Vector Access A-k check failure\n"); index_t LastK =
return false; AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
if(LastK % ABlockTransferSrcScalarPerVector == 0)
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
}
} }
} }
......
...@@ -70,8 +70,9 @@ __global__ void ...@@ -70,8 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -56,7 +56,7 @@ __global__ void ...@@ -56,7 +56,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -159,6 +159,7 @@ __global__ void ...@@ -159,6 +159,7 @@ __global__ void
ignore = O; ignore = O;
ignore = G0; ignore = G0;
ignore = G1; ignore = G1;
ignore = alpha;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx11__)) #endif // end of if (defined(__gfx11__))
...@@ -187,7 +188,7 @@ __global__ void ...@@ -187,7 +188,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -321,7 +322,7 @@ __global__ void ...@@ -321,7 +322,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false; return false;
} }
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
std::is_same<ADataType, double>::value)
{ {
return false; return false;
} }
......
...@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported())) ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{ {
return false; return false;
} }
......
...@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout, ...@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
} }
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()) ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
...@@ -50,8 +50,9 @@ __global__ void ...@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
...@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
...@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16; static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto = static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
? false
: true;
static constexpr auto BEnableLds_auto = static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true; (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
? false
: true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
...@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -48,8 +48,9 @@ __global__ void ...@@ -48,8 +48,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(ck::is_gfx11_supported()) if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -90,8 +90,9 @@ __global__ void ...@@ -90,8 +90,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported())) ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment