Unverified Commit 5018ac6a authored by jayhshah's avatar jayhshah Committed by GitHub
Browse files

Fp8 kernel with "in-kernel" transpose of V in producer (#1100)

* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* use seqlen traits

* add fp8 .cu files and benchmark script

* fix merge issue

* fix merge issue

* fix merge issue

* remove duplicate code

* fix regression with varseqlen

* move varseqlen init in constexpr

* fix test script

* more constexpr on varseqlen and add max offset

* add back test cases
parent c4b9015d
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
from flash_attn import flash_attn_qkvpacked_func
from flash_attn_interface import flash_attn_func
try:
from triton_fused_attention import attention as attention_triton
except ImportError:
attention_triton = None
try:
import xformers.ops as xops
except ImportError:
xops = None
try:
import cudnn
except ImportError:
cudnn = None
def convert_to_cudnn_type(torch_type):
if torch_type == torch.float16:
return cudnn.data_type.HALF
elif torch_type == torch.bfloat16:
return cudnn.data_type.BFLOAT16
elif torch_type == torch.float32:
return cudnn.data_type.FLOAT
elif torch_type == torch.int32:
return cudnn.data_type.INT32
elif torch_type == torch.int64:
return cudnn.data_type.INT64
elif torch_type == torch.float8_e4m3fn:
return cudnn.data_type.FP8_E4M3
elif torch_type == torch.float8_e4m3fn:
return cudnn.data_type.FP8_E5M2
else:
raise ValueError("Unsupported tensor data type.")
def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
b, _, _, nheads, headdim = qkv.shape
assert cudnn is not None, 'CUDNN is not available'
o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
o_gpu_transposed = torch.as_strided(
o_gpu,
[b, nheads, seqlen_q, headdim],
[nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
)
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
graph = cudnn.pygraph(
io_data_type=convert_to_cudnn_type(qkv.dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)
new_q = torch.as_strided(
qkv,
[b, nheads, seqlen_q, headdim],
[seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
storage_offset=0,
)
q = graph.tensor(
name = "Q",
dim = list(new_q.shape),
stride = list(new_q.stride()),
data_type=convert_to_cudnn_type(qkv.dtype)
)
new_k = torch.as_strided(
qkv,
[b, nheads, seqlen_k, headdim],
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
storage_offset=nheads * headdim,
)
k = graph.tensor(
name = "K",
dim = list(new_k.shape),
stride = list(new_k.stride()),
data_type=convert_to_cudnn_type(qkv.dtype)
)
new_v = torch.as_strided(
qkv,
[b, nheads, seqlen_k, headdim],
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
storage_offset=nheads * headdim * 2,
)
v = graph.tensor(
name = "V",
dim = list(new_v.shape),
stride = list(new_v.stride()),
data_type=convert_to_cudnn_type(qkv.dtype)
)
def get_default_scale_tensor():
return graph.tensor(
dim = [1, 1, 1, 1],
stride = [1, 1, 1, 1],
data_type=cudnn.data_type.FLOAT
)
default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
descale_q = get_default_scale_tensor()
descale_k = get_default_scale_tensor()
descale_v = get_default_scale_tensor()
descale_s = get_default_scale_tensor()
scale_s = get_default_scale_tensor()
scale_o = get_default_scale_tensor()
o, _, amax_s, amax_o = graph.sdpa_fp8(
q=q,
k=k,
v=v,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_s=descale_s,
scale_s=scale_s,
scale_o=scale_o,
is_inference=True,
attn_scale=1.0 / math.sqrt(headdim),
use_causal_mask=causal,
name="sdpa",
)
o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
# stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
graph.check_support()
graph.build_plans()
variant_pack = {
q: new_q,
k: new_k,
v: new_v,
descale_q: default_scale_gpu,
descale_k: default_scale_gpu,
descale_v: default_scale_gpu,
descale_s: default_scale_gpu,
scale_s: default_scale_gpu,
scale_o: default_scale_gpu,
o: o_gpu_transposed,
amax_s: amax_s_gpu,
amax_o: amax_o_gpu,
}
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
def run(*args, **kwargs):
graph.execute(variant_pack, workspace)
return o_gpu, amax_o_gpu
return run
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
dropout_p: float
Output:
output: (batch_size, seqlen, nheads, head_dim)
"""
batch_size, seqlen, _, nheads, d = qkv.shape
q, k, v = qkv.unbind(dim=2)
q = rearrange(q, 'b t h d -> (b h) t d')
k = rearrange(k, 'b s h d -> (b h) d s')
softmax_scale = 1.0 / math.sqrt(d)
# Preallocate attn_weights for `baddbmm`
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
'(b h) t s -> b h t s', h=nheads)
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1)
attention_drop = F.dropout(attention, dropout_p)
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
return output.to(dtype=qkv.dtype)
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
assert mode in ["fwd", "bwd", "fwd_bwd"]
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
def efficiency(flop, time):
return (flop / time / 10**12) if not math.isnan(time) else 0.0
def time_fwd(func, *args, **kwargs):
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
time_f = benchmark_forward(func, *args, **kwargs)
return time_f[1].mean
torch.manual_seed(0)
repeats = 30
device = 'cuda'
# dtype = torch.float16
dtype = torch.float8_e4m3fn
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
causal_vals = [False, True]
headdim_vals = [128]
dim = 2048
# dim = 256
dropout_p = 0.0
methods = (["Pytorch", "Flash3", "cuDNN"]
# + (["Triton"] if attention_triton is not None else [])
# + (["xformers.c"] if xops is not None else [])
# + (["xformers.f"] if xops is not None else [])
)
time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
for headdim in headdim_vals:
for batch_size, seqlen in bs_seqlen_vals:
torch.cuda.empty_cache()
config = (causal, headdim, batch_size, seqlen)
nheads = dim // headdim
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
qkv = torch.stack([q, k, v], dim=2)
qkv = qkv.to(torch.float16)
f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
time_f[config, "Pytorch"] = f
res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
if attention_triton is not None:
q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
scale = 1 / math.sqrt(headdim)
f = time_fwd(
attention_triton, q_transposed, k_transposed, v_transposed,
causal, scale, repeats=5, verbose=False, desc='Triton'
)
f = time_fwd(
attention_triton, q_transposed, k_transposed, v_transposed,
causal, scale, repeats=repeats, verbose=False, desc='Triton'
)
time_f[config, "Triton"] = f
res = attention_triton(
q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
causal, scale
).half().transpose(1, 2)
torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
# out = torch.empty_like(q)
q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
# res = flash_attn_func(q, k, v, causal=causal)
# torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
time_f[config, "Flash3"] = f
if cudnn is not None:
qkv_fp8 = qkv.to(dtype)
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
f = time_fwd(
cudnn_spda_setup(
qkv_fp8, seqlen, seqlen,
causal=causal
),
repeats=repeats, verbose=False
)
time_f[config, "cuDNN"] = f
# res, amax_o = cudnn_spda_setup(
# qkv_fp8, seqlen, seqlen,
# causal=causal
# )()
# res = res.half()
# TODO: CUDNN has numerics issues when
# num_heads=16, dim=128, seq_len=1024, batch_size=2
# or larger sizes.
# res_cpu = res.cpu().reshape(-1)
# res_baseline_cpu = res_baseline.cpu().reshape(-1)
# print(amax_o)
# print(res)
# print(res_baseline)
# for i in range(len(res_cpu)):
# item = res_cpu[i]
# item_baseline = res_baseline_cpu[i]
# if abs(item - item_baseline) > 0.5:
# print(i)
# print(item)
# print(item_baseline)
# torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
for method in methods:
speed_f[config, method] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
time_f[config, method]
)
#print (time_f[config,method])
print(
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
)
# with open('flash3_attn_time.plk', 'wb') as fp:
# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
...@@ -20,7 +20,7 @@ using namespace cute; ...@@ -20,7 +20,7 @@ using namespace cute;
template <typename Ktraits, typename Seqlen_traits> template <typename Ktraits, typename Seqlen_traits>
struct CollectiveEpilogueFwd { struct CollectiveEpilogueFwd {
using Element = typename Ktraits::Element; using Element = typename Ktraits::OutputType;
static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim; static constexpr int kHeadDim = Ktraits::kHeadDim;
...@@ -28,7 +28,7 @@ struct CollectiveEpilogueFwd { ...@@ -28,7 +28,7 @@ struct CollectiveEpilogueFwd {
static constexpr int kNWarps = Ktraits::kNWarps; static constexpr int kNWarps = Ktraits::kNWarps;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_WS = kNWarps >= 12; static constexpr bool Is_WS = kNWarps >= 12;
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
...@@ -71,6 +71,16 @@ struct CollectiveEpilogueFwd { ...@@ -71,6 +71,16 @@ struct CollectiveEpilogueFwd {
TiledCopyOValLayout{} // Val layout TiledCopyOValLayout{} // Val layout
)); ));
// used for rmem -> smem O copy in fp8 kernel to undo column permutation
using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
Stride<_4, _32, _1, _0>>;
using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
Stride<_0, _2, Stride<_4, _1>, _8>>;
using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, Element>{},
ThreadLayoutrO{}, ValueLayoutrO{}));
using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
// Host side kernel arguments // Host side kernel arguments
struct Arguments { struct Arguments {
Element* ptr_O; Element* ptr_O;
...@@ -150,7 +160,7 @@ struct CollectiveEpilogueFwd { ...@@ -150,7 +160,7 @@ struct CollectiveEpilogueFwd {
if (get<1>(taccOcO_row(_0{})) == 0) { if (get<1>(taccOcO_row(_0{})) == 0) {
#pragma unroll #pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi)); const int row = get<0>(taccOcO_row(mi));
if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); } if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
} }
} }
...@@ -170,6 +180,73 @@ struct CollectiveEpilogueFwd { ...@@ -170,6 +180,73 @@ struct CollectiveEpilogueFwd {
); );
} }
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
CUTLASS_DEVICE void
store_fp8(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
const Seqlen_traits& seqlen_traits_q
) {
// using SmemLayoutrO = typename Ktraits::SmemLayoutrO;
// using TiledCopyrO = typename Ktraits::TiledCopyrO;
auto [m_block, bidh, bidb] = block_coord;
TiledCopyrO rmem_tiled_copy_O;
Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{});
auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc);
Tensor tOrO_out = flash::convert_type<Element>(tOrO); // Element is Ktraits::OutputType
Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO));
// Make sure all WGs have finished reading V
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
// taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
int const seqlen_q = [&] {
if constexpr(Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
else { return shape<2>(epilogue_params.layout_LSE); }
}();
if (get<1>(taccOcO_row(_0{})) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
int write_warp_idx = kNWarps - 1;
if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
TiledCopyO gmem_tiled_copy_O;
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O,
epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO,
m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
);
}
CUTLASS_DEVICE void CUTLASS_DEVICE void
store_tail() { store_tail() {
tma_store_wait<0>(); tma_store_wait<0>();
......
...@@ -249,7 +249,13 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split ...@@ -249,7 +249,13 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
} }
} }
} else { } else {
// run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream); if (params.d == 64) {
run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
} else if (params.d == 256) {
run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
}
} }
} }
...@@ -266,12 +272,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -266,12 +272,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
auto q_dtype = q.dtype(); auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type for now"); // "FlashAttention only support fp16 and bf16 data type for now");
// TODO: will add e4m3 later // TODO: will add e4m3 later
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
// "FlashAttention only support fp16 and bf16 data type"); // "FlashAttention only support fp16 and bf16 data type");
// "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...@@ -317,13 +323,21 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -317,13 +323,21 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
at::Tensor out; at::Tensor out;
if (out_.has_value()) { if (out_.has_value()) {
out = out_.value(); out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
? (out.dtype() == at::kHalf)
: (out.dtype() == q_dtype),
"Output must have the same dtype as input dtype if dtype is "
"not fp8, or fp16 for fp8 input.");
CHECK_DEVICE(out); CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else { } else {
out = torch::empty_like(q_padded); if (q_dtype == at::ScalarType::Float8_e4m3fn)
out = torch::empty_like(q_padded, at::kHalf);
else
out = torch::empty_like(q_padded);
} }
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
...@@ -534,13 +548,13 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { ...@@ -534,13 +548,13 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
// run_mha_bwd_<elem_type, kHeadDim>(params, stream); // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
// }); // });
// }); // });
if (params.d == 64) { // if (params.d == 64) {
run_mha_bwd_<cutlass::half_t, 64>(params, stream); // run_mha_bwd_<cutlass::half_t, 64>(params, stream);
} else if (params.d == 128) { // } else if (params.d == 128) {
run_mha_bwd_<cutlass::half_t, 128>(params, stream); // run_mha_bwd_<cutlass::half_t, 128>(params, stream);
} else { // } else {
run_mha_bwd_<cutlass::half_t, 256>(params, stream); // run_mha_bwd_<cutlass::half_t, 256>(params, stream);
} // }
} }
std::vector<at::Tensor> std::vector<at::Tensor>
......
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128_fp8<cutlass::float_e4m3_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256_fp8<cutlass::float_e4m3_t>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64_fp8<cutlass::float_e4m3_t>(params, stream);
}
...@@ -188,4 +188,198 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, ...@@ -188,4 +188,198 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
} }
template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
) {
using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static_assert(Ktraits::Is_WS);
static constexpr bool Is_WS = Ktraits::Is_WS;
static constexpr bool kUseVarSeqLen = Seqlen_traits::kUseVarSeqLen;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;
// for now, disable for hdim 128 causal to avoid perf regression with register spilling
static constexpr bool Use_max_offset = !(Is_causal && Ktraits::kHeadDim == 128);
using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineParamsVt = typename MainloopPipelineVt::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
// additional pipeline to synchronize out-of-place smem transpose of V
PipelineParamsVt pipeline_params_vt;
pipeline_params_vt.producer_arv_count = NumCopyThreads;
pipeline_params_vt.consumer_arv_count = NumMmaThreads;
MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt);
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
// pipeline_v has producer warpgroup for its consumer in fp8 kernel
pipeline_params.num_consumers = NumCopyThreads;
pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
cute::cluster_wait();
} else {
__syncthreads();
}
static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_read, smem_pipe_release;
int work_idx = 0;
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
for (auto work_tile_info = scheduler.get_initial_work();
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
if constexpr(kUseVarSeqLen) {
seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
continue;
}
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if constexpr(Is_causal) {
if(n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
// need to sync producer warpgroup
cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
continue;
}
}
collective_mainloop.load_fp8(
mainloop_params, pipeline_k, pipeline_v, pipeline_vt,
smem_pipe_write, smem_pipe_read, shared_storage,
scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
seqlen_traits_q, seqlen_traits_k);
++work_idx;
// don't need to sync producer warpgroup here
// if constexpr (Is_causal) {
// cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/); }
}
collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write);
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
// Initialize matmul objects.
typename Ktraits::TiledMma1 tiled_mma1;
PipelineState smem_pipe_read;
PipelineState smem_pipe_release;
collective_mainloop.mma_init();
scheduler.init_consumer();
int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
for (auto work_tile_info = scheduler.get_initial_work();
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax;
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
if constexpr(kUseVarSeqLen) {
seqlen_traits_q.init(bidb);
seqlen_traits_k.init(bidb);
if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
continue;
}
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if constexpr(Is_causal) {
if(n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
continue;
}
}
collective_mainloop.mma_fp8<Delay_V_release>(
mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
tOrO, softmax, n_block_max,
threadIdx.x - NumCopyThreads, work_idx, m_block,
shared_storage, seqlen_traits_q, seqlen_traits_k);
#ifndef NO_FP8_COLUMN_PERMUTE
collective_epilogue.store_fp8(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
#else
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
#endif
++work_idx;
}
collective_epilogue.store_tail();
}
}
} // namespace flash } // namespace flash
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits> template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
using OutputType = typename Kernel_traits::OutputType;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
...@@ -32,7 +33,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -32,7 +33,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
flash::SingleTileScheduler, flash::SingleTileScheduler,
std::conditional_t<!Is_causal, std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler, flash::StaticPersistentTileScheduler,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup> flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, Kernel_traits::NumProducerThreads>
>>; >>;
// using Scheduler = flash::SingleTileScheduler; // using Scheduler = flash::SingleTileScheduler;
Seqlen_traits seqlen_traits_q( Seqlen_traits seqlen_traits_q(
...@@ -60,7 +61,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -60,7 +61,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
}); });
typename CollectiveEpilogue::Params epilogue_params = typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({ CollectiveEpilogue::to_underlying_arguments({
static_cast<Element*>(params.o_ptr), static_cast<OutputType*>(params.o_ptr),
seqlen_traits_q.get_gmem_layout( seqlen_traits_q.get_gmem_layout(
params.seqlen_q, params.d, params.h, params.b, params.seqlen_q, params.d, params.h, params.b,
params.o_row_stride, params.o_head_stride, params.o_batch_stride params.o_row_stride, params.o_head_stride, params.o_batch_stride
...@@ -78,12 +79,16 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -78,12 +79,16 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// Get the ptr to kernel function. // Get the ptr to kernel function.
void *kernel; void *kernel;
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>; if constexpr(cutlass::sizeof_bits_v<Element> == 8)
kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
else
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage); int smem_size = sizeof(typename Kernel_traits::SharedStorage);
// int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
// int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v)); // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o));
// printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o);
if (smem_size >= 48 * 1024) { if (smem_size >= 48 * 1024) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
} }
...@@ -151,3 +156,60 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -151,3 +156,60 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
}); });
}); });
} }
template<typename T>
void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
constexpr static int kBlockM = 192;
constexpr static int kBlockN = 128;
constexpr static int kNWarps = 4 + kBlockM/16;
constexpr static int kStages = 4;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
!Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
});
});
}
template<typename T>
void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kBlockM = 128;
constexpr static int kBlockN = 256;
constexpr static int kNWarps = 4 + kBlockM/16;
constexpr static int kStages = 2;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
!Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
});
});
}
template<typename T>
void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
constexpr static int kBlockM = 128;
constexpr static int kBlockN = 128;
constexpr static int kNWarps = 4 + kBlockM/16;
constexpr static int kStages = 2;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
!Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
});
});
});
}
...@@ -33,17 +33,41 @@ struct SharedStorageQKVO { ...@@ -33,17 +33,41 @@ struct SharedStorageQKVO {
}; };
}; };
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
struct {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
cutlass::arch::ClusterBarrier barrier_O;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
int tile_count_semaphore;
};
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false, template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::half_t> int kClusterM_ = 1, typename elem_type=cutlass::half_t>
struct Flash_fwd_kernel_traits { struct Flash_fwd_kernel_traits {
using Element = elem_type; using Element = elem_type;
using ElementAccum = float; using ElementAccum = float;
using OutputType = elem_type;
using index_t = int64_t; using index_t = int64_t;
// The number of threads. // The number of threads.
static constexpr int kNWarps = kNWarps_; static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
...@@ -88,9 +112,16 @@ struct Flash_fwd_kernel_traits { ...@@ -88,9 +112,16 @@ struct Flash_fwd_kernel_traits {
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV = using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{}, decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{}))); make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(composition(SmemLayoutV{},
make_ordered_layout(
make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
...@@ -100,11 +131,122 @@ struct Flash_fwd_kernel_traits { ...@@ -100,11 +131,122 @@ struct Flash_fwd_kernel_traits {
SmemLayoutK, SmemLayoutV, SmemLayoutO>; SmemLayoutK, SmemLayoutV, SmemLayoutO>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>; using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>; using PipelineState = typename cutlass::PipelineState<kStages>;
// using BarrierType = typename MainloopPipeline::ProducerBarrierType; // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
}; };
// Traits struct for fp8 kernel with in-kernel transpose
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
struct Flash_fwd_kernel_traits_fp8 {
using Element = elem_type;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using ElementAccum = float;
using OutputType = cutlass::half_t;
using index_t = int64_t;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
static_assert(kNWarps_ == 12 || kNWarps_ == 16);
static constexpr bool Is_WS = true;
static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using TransposeShapeAtomV = Shape<_64, _64>;
using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
// for fp8 in-kernel transpose -- src layout
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
// For fp8, this is the memory transpose.
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutVt =
decltype(tile_to_shape(SmemLayoutAtomVt{},
make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
// for fp8 in-kernel transpose -- dst layout
using SmemLayoutVtTrans =
decltype(composition(SmemLayoutVt{},
make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
// used for rmem -> smem O copy in fp8 kernel to undo column permutation
using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
Stride<_4, _32, _1, _0>>;
using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
Stride<_0, _2, Stride<_4, _1>, _8>>;
using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
ThreadLayoutrO{}, ValueLayoutrO{}));
using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SharedStorage = SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ,
SmemLayoutK, SmemLayoutV, SmemLayoutO>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
// using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO, template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
......
...@@ -21,6 +21,64 @@ namespace flash { ...@@ -21,6 +21,64 @@ namespace flash {
using namespace cute; using namespace cute;
// 4 warps
struct SmemTransposeFp8_64x64 {
using Element = cutlass::float_e4m3_t;
using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
using ldsm_value_shape = Shape<_2, _8, _2, _1>;
using ldsm_value_stride = Stride<_2, _4, _1, _0>;
using TiledCopyLDSM = decltype(make_tiled_copy(
Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
Layout<ldsm_value_shape, ldsm_value_stride>{}));
TiledCopyLDSM tiled_copy_ldsm;
using stsm_thread_shape = Shape<_4, _1, _8, _4>;
// using stsm_thread_stride = Stride<_1, _0, _4, _32>;
#ifndef NO_FP8_COLUMN_PERMUTE
using stsm_value_shape = Shape<_4, _4, _1, _2>;
using stsm_value_stride = Stride<_1, _8, _0, _4>;
#else
using stsm_value_shape = Shape<_4, _4, _2, _1>;
using stsm_value_stride = Stride<_1, _8, _4, _0>;
#endif
using TiledCopySTSM =
decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
Layout<stsm_thread_shape>{},
Layout<stsm_value_shape, stsm_value_stride>{}));
TiledCopySTSM tiled_copy_stsm;
template <class SmemTensor, class SmemTensorOut>
CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
using namespace cute;
auto tid = threadIdx.x;
auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
auto tXsX = thr_copy_ldsm.partition_S(s_in);
auto tXrX = make_tensor<Element>(shape(tXsX));
auto tXsX_out = thr_copy_stsm.partition_D(s_out);
cute::copy(tiled_copy_ldsm, tXsX, tXrX);
auto data = tXrX.data();
// size(tXrX) == 32
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size(tXrX); n += 8) {
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
auto upper = data_32bit[0];
auto lower = data_32bit[1];
data_32bit[0] = __byte_perm(upper, lower, 0x6420);
data_32bit[1] = __byte_perm(upper, lower, 0x7531);
}
cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
}
};
template <typename Ktraits, bool Is_causal, typename Seqlen_traits> template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
struct CollectiveMainloopFwd { struct CollectiveMainloopFwd {
...@@ -29,40 +87,15 @@ struct CollectiveMainloopFwd { ...@@ -29,40 +87,15 @@ struct CollectiveMainloopFwd {
using ClusterShape = typename Ktraits::ClusterShape_MNK; using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages; static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim; static constexpr int kHeadDim = Ktraits::kHeadDim;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
// using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
// using SmemLayoutVt =
// decltype(tile_to_shape(SmemLayoutAtomVt{},
// make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
// Step<_2, _1, _3>{})); // This gives correct results, without Step it's wrong
// using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::MN, Element,
// decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
// using SmemLayoutVt =
// decltype(tile_to_shape(SmemLayoutAtomVt{},
// make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
// using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom<Element>;
// using SmemLayoutVTMA =
// decltype(tile_to_shape(SmemLayoutAtomVTMA{},
// make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using TMA_Q = decltype(make_tma_copy( using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{}, GmemTiledCopyQ{},
...@@ -75,7 +108,7 @@ struct CollectiveMainloopFwd { ...@@ -75,7 +108,7 @@ struct CollectiveMainloopFwd {
select<0, 2>(TileShape_MNK{}), select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q _1{})); // no mcast for Q
using TMA_KV = decltype(make_tma_copy( using TMA_K = decltype(make_tma_copy(
GmemTiledCopyKV{}, GmemTiledCopyKV{},
make_tensor( make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)), make_gmem_ptr(static_cast<Element const*>(nullptr)),
...@@ -86,8 +119,21 @@ struct CollectiveMainloopFwd { ...@@ -86,8 +119,21 @@ struct CollectiveMainloopFwd {
select<1, 2>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
// TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
using TMA_V = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
typename Seqlen_traits::StrideT{}
),
take<0, 2>(SmemLayoutV{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline; using MainloopPipeline = typename Ktraits::MainloopPipeline;
using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState; using PipelineState = typename MainloopPipeline::PipelineState;
...@@ -95,7 +141,10 @@ struct CollectiveMainloopFwd { ...@@ -95,7 +141,10 @@ struct CollectiveMainloopFwd {
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8); static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
static constexpr bool UseSchedulerBarrier =
cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128
: kHeadDim <= 128;
// Host side kernel arguments // Host side kernel arguments
struct Arguments { struct Arguments {
...@@ -114,8 +163,9 @@ struct CollectiveMainloopFwd { ...@@ -114,8 +163,9 @@ struct CollectiveMainloopFwd {
typename Seqlen_traits::LayoutT layout_K; typename Seqlen_traits::LayoutT layout_K;
typename Seqlen_traits::LayoutT layout_V; typename Seqlen_traits::LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod; cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q; TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V; TMA_K tma_load_K;
TMA_V tma_load_V;
float const softmax_scale_log2; float const softmax_scale_log2;
}; };
...@@ -130,14 +180,14 @@ struct CollectiveMainloopFwd { ...@@ -130,14 +180,14 @@ struct CollectiveMainloopFwd {
select<0, 2>(TileShape_MNK{}), select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for Q _1{}); // no mcast for Q
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy( TMA_K tma_load_K = make_tma_copy(
GmemTiledCopyKV{}, GmemTiledCopyKV{},
mK, mK,
SmemLayoutK{}(_, _, _0{}), SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy( TMA_V tma_load_V = make_tma_copy(
GmemTiledCopyKV{}, GmemTiledCopyKV{},
mV, mV,
SmemLayoutV{}(_, _, _0{}), SmemLayoutV{}(_, _, _0{}),
...@@ -164,9 +214,9 @@ struct CollectiveMainloopFwd { ...@@ -164,9 +214,9 @@ struct CollectiveMainloopFwd {
const Seqlen_traits& seqlen_traits_k const Seqlen_traits& seqlen_traits_k
) { ) {
static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{});
int const seqlen_q = seqlen_traits_q.actual_seq_len; int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
int const seqlen_k = seqlen_traits_k.actual_seq_len; int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
int n_block_max = cute::ceil_div(seqlen_k, kBlockN); int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
if constexpr (Is_causal) { if constexpr (Is_causal) {
n_block_max = std::min(n_block_max, n_block_max = std::min(n_block_max,
...@@ -279,13 +329,242 @@ struct CollectiveMainloopFwd { ...@@ -279,13 +329,242 @@ struct CollectiveMainloopFwd {
scheduler.broadcast_next_work(work_tile_info); scheduler.broadcast_next_work(work_tile_info);
} }
template <typename Scheduler, typename SharedStorage>
CUTLASS_DEVICE void
load_fp8(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
MainloopPipelineNoTMA pipeline_vt,
PipelineState& smem_pipe_write,
PipelineState& smem_pipe_read,
SharedStorage &shared_storage,
Scheduler& scheduler,
typename Scheduler::Params const& scheduler_params,
typename Scheduler::WorkTileInfo& work_tile_info,
cute::tuple<int32_t, int32_t, int32_t> block_coord,
int work_idx,
const Seqlen_traits& seqlen_traits_q,
const Seqlen_traits& seqlen_traits_k
) {
using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));
auto smem_transpose_V = SmemTransposeFp8_64x64();
auto do_transpose_V = [&](int stage) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
flatten(sVt_divide(_, i, j, stage)));
}
}
};
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
auto [m_block, bidh, bidb] = block_coord;
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K)
Tensor gK = seqlen_traits_k.get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
Tensor gV = seqlen_traits_k.get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA)
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE)
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE)
uint16_t mcast_mask_kv = 0;
if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
}
}
int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
}
// Wait for the MMA warpgroups to say that smem_q is ready
// for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
if constexpr(Is_causal) {
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
}
shared_storage.barrier_O.wait((work_idx + 1) % 2);
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) {
pipeline_v.consumer_wait(smem_pipe_read);
// pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
}
}
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_v.consumer_wait(smem_pipe_read);
pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
}
}
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
pipeline_v.consumer_wait(smem_pipe_read);
if (n_block_max > kStages)
pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
} else {
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
}
// With fp8 kernel, smem_o is in union with smem_v_out,
// so could use NamedBarrier instead of ClusterBarrier.
// But, this doesn't appear to have any benefit.
shared_storage.barrier_O.wait((work_idx + 1) % 2);
pipeline_v.consumer_wait(smem_pipe_read);
// pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
--n_block;
constexpr int extra_iterations = kStages - 1;
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter) {
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
}
pipeline_v.consumer_wait(smem_pipe_read);
// pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
--n_block;
}
// CUTLASS_PRAGMA_NO_UNROLL
#pragma unroll 2
for (; n_block >= 0; --n_block) {
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
pipeline_v.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
}
pipeline_v.consumer_wait(smem_pipe_read);
pipeline_vt.producer_acquire(smem_pipe_write);
do_transpose_V(smem_pipe_read.index());
pipeline_vt.producer_commit(smem_pipe_write);
pipeline_v.consumer_release(smem_pipe_read);
++smem_pipe_write;
++smem_pipe_read;
}
// scheduler.prefetch_next_work(scheduler_params, work_tile_info);
// scheduler.broadcast_next_work(work_tile_info);
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) { PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
int lane_predicate = cute::elect_one_sync(); int lane_predicate = cute::elect_one_sync();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
// Issue the epilogue waits // Issue the epilogue waits
if (lane_predicate) { if (warp_idx_in_warpgroup == 0 && lane_predicate) {
/* This helps avoid early exit of blocks in Cluster /* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was still inverted from make_producer_start_state * then would just be acquired since the phase was still inverted from make_producer_start_state
...@@ -295,6 +574,23 @@ struct CollectiveMainloopFwd { ...@@ -295,6 +574,23 @@ struct CollectiveMainloopFwd {
} }
} }
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
// Issue the epilogue waits
if (warp_idx_in_warpgroup == 0 && lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was still inverted from make_producer_start_state
*/
pipeline_k.producer_tail(smem_pipe_write);
pipeline_v.producer_tail(smem_pipe_write);
}
}
CUTLASS_DEVICE void CUTLASS_DEVICE void
warp_scheduler_barrier_sync() { warp_scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) { if constexpr (UseSchedulerBarrier) {
...@@ -317,7 +613,7 @@ struct CollectiveMainloopFwd { ...@@ -317,7 +613,7 @@ struct CollectiveMainloopFwd {
CUTLASS_DEVICE void CUTLASS_DEVICE void
mma_init() { mma_init() {
// Tell producer (warp 0) that smem_q is ready // Tell producer (warp 0) that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
if constexpr (!UseSchedulerBarrier) { return; } if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) { if (cutlass::canonical_warp_group_idx() > 1) {
...@@ -387,6 +683,7 @@ struct CollectiveMainloopFwd { ...@@ -387,6 +683,7 @@ struct CollectiveMainloopFwd {
warp_scheduler_barrier_sync(); warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warp_scheduler_barrier_arrive(); warp_scheduler_barrier_arrive();
if (work_idx != 0) { if (work_idx != 0) {
int lane_predicate = cute::elect_one_sync(); int lane_predicate = cute::elect_one_sync();
if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
...@@ -495,6 +792,234 @@ struct CollectiveMainloopFwd { ...@@ -495,6 +792,234 @@ struct CollectiveMainloopFwd {
return; return;
} }
template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma_fp8(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipelineNoTMA pipeline_vt,
PipelineState& smem_pipe_read,
PipelineState& smem_pipe_release,
FrgTensorO& tOrO,
Softmax& softmax,
int n_block_count,
int thread_idx,
int work_idx,
int m_block,
SharedStorage& shared_storage,
const Seqlen_traits& seqlen_traits_q,
const Seqlen_traits& seqlen_traits_k
) {
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
// Allocate "fragments/descriptors" for first matmul.
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
// Allocate "fragments/descriptors" for second matmul.
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
// workaround for fp8 only perf regression pending change to seqlen traits class
int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
int n_block = n_block_count - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read);
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
if (work_idx != 0) {
int lane_predicate = cute::elect_one_sync();
if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
tma_store_wait<0>();
#pragma unroll
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
shared_storage.barrier_O.arrive(cta_id, lane_predicate);
}
}
}
warpgroup_wait<0>();
warp_scheduler_barrier_arrive();
pipeline_k.consumer_release(smem_pipe_read);
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
};
{
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if constexpr (!Is_causal) { // Just masking based on col
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
} else { // mask based on both row and col
if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
}
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
permute_regs_A_to_C(tOrP);
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
consumer_wait(pipeline_vt, smem_pipe_read);
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
++smem_pipe_read;
--n_block;
constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN);
if constexpr(Is_causal) {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read);
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block)) {
tSrS(i) = -INFINITY;
}
}
warp_scheduler_barrier_arrive();
pipeline_k.consumer_release(smem_pipe_read);
consumer_wait(pipeline_vt, smem_pipe_read);
cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.rescale_o(tOrO, scores_scale);
softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
permute_regs_A_to_C(tOrP);
if constexpr(Delay_V_release) {
pipeline_vt.consumer_release(smem_pipe_release);
++smem_pipe_release;
}
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
++smem_pipe_read;
}
} else {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read);
if constexpr(Delay_V_release) {
pipeline_vt.consumer_release(smem_pipe_release);
++smem_pipe_release;
}
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
warp_scheduler_barrier_arrive();
if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
else { consumer_wait(pipeline_vt, smem_pipe_read); }
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.rescale_o(tOrO, scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
permute_regs_A_to_C(tOrP);
if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
else { consumer_wait(pipeline_vt, smem_pipe_read); }
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
++smem_pipe_read;
}
}
if constexpr(Delay_V_release) {
warp_scheduler_barrier_sync();
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read);
pipeline_vt.consumer_release(smem_pipe_release);
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
warp_scheduler_barrier_arrive();
warpgroup_wait<0>();
consumer_wait(pipeline_vt, smem_pipe_read);
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.rescale_o(tOrO, scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
permute_regs_A_to_C(tOrP);
pipeline_k.consumer_release(smem_pipe_read);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
warp_scheduler_barrier_sync();
warpgroup_wait<0>();
++smem_pipe_read;
++smem_pipe_release;
}
warp_scheduler_barrier_arrive();
pipeline_vt.consumer_release(smem_pipe_release);
++smem_pipe_release;
} else {
if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
CUTLASS_PRAGMA_NO_UNROLL
for (; n_block >= 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read);
if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
warp_scheduler_barrier_arrive();
pipeline_k.consumer_release(smem_pipe_read);
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.rescale_o(tOrO, scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
permute_regs_A_to_C(tOrP);
consumer_wait(pipeline_vt, smem_pipe_read);
if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
pipeline_vt.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
}
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.rescale_o(tOrO, scores_scale);
return;
}
}; };
} // namespace flash } // namespace flash
...@@ -18,6 +18,7 @@ enum class FwdNamedBarriers { ...@@ -18,6 +18,7 @@ enum class FwdNamedBarriers {
WarpSchedulerWG1 = 4, WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5, WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6, WarpSchedulerWG3 = 6,
ProducerWG = 7
}; };
} // flash } // flash
\ No newline at end of file
...@@ -119,7 +119,9 @@ if not SKIP_CUDA_BUILD: ...@@ -119,7 +119,9 @@ if not SKIP_CUDA_BUILD:
"flash_bwd_hdim64_fp16_sm90.cu", "flash_bwd_hdim64_fp16_sm90.cu",
"flash_bwd_hdim128_fp16_sm90.cu", "flash_bwd_hdim128_fp16_sm90.cu",
"flash_bwd_hdim256_fp16_sm90.cu", "flash_bwd_hdim256_fp16_sm90.cu",
# "flash_fwd_hdim128_e4m3_sm90.cu", "flash_fwd_hdim64_e4m3_sm90.cu",
"flash_fwd_hdim128_e4m3_sm90.cu",
"flash_fwd_hdim256_e4m3_sm90.cu"
] ]
nvcc_flags = [ nvcc_flags = [
"-O3", "-O3",
...@@ -134,15 +136,11 @@ if not SKIP_CUDA_BUILD: ...@@ -134,15 +136,11 @@ if not SKIP_CUDA_BUILD:
"--expt-relaxed-constexpr", "--expt-relaxed-constexpr",
"--expt-extended-lambda", "--expt-extended-lambda",
"--use_fast_math", "--use_fast_math",
# "--ptxas-options=-v", # printing out number of registers "--ptxas-options=-v", # printing out number of registers
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers
"-lineinfo", "-lineinfo",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging
"-DNDEBUG", # Important, otherwise performance is severely impacted "-DNDEBUG", # Important, otherwise performance is severely impacted
"-DQBLKSIZE=128",
"-DKBLKSIZE=128",
"-DCTA256",
"-DDQINRMEM",
] ]
include_dirs = [ include_dirs = [
# Path(this_dir) / "fmha-pipeline", # Path(this_dir) / "fmha-pipeline",
...@@ -161,7 +159,7 @@ if not SKIP_CUDA_BUILD: ...@@ -161,7 +159,7 @@ if not SKIP_CUDA_BUILD:
"cxx": ["-O3", "-std=c++17"], "cxx": ["-O3", "-std=c++17"],
# "cxx": ["-O0", "-std=c++17"], # "cxx": ["-O0", "-std=c++17"],
"nvcc": append_nvcc_threads( "nvcc": append_nvcc_threads(
nvcc_flags + ["-DEXECMODE=0"] + cc_flag nvcc_flags + cc_flag
), ),
}, },
include_dirs=include_dirs, include_dirs=include_dirs,
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "utils.h" #include "utils.h"
#include "cutlass/fast_math.h"
namespace flash { namespace flash {
using namespace cute; using namespace cute;
...@@ -100,8 +102,10 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &ten ...@@ -100,8 +102,10 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &ten
} }
// Apply the exp to all the elements. // Apply the exp to all the elements.
template <bool Scale_max=true, bool Check_inf=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <bool Scale_max=true, bool Check_inf=true, bool Use_max_offset=false,
typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) { __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
...@@ -111,8 +115,8 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -111,8 +115,8 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// We don't want (-inf - (-inf)) since that would give NaN. // We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64. // If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = Check_inf const float max_scaled = Check_inf
? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E)))) ? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E))) - max_offset)
: (max(mi) * (Scale_max ? scale : float(M_LOG2E))); : (max(mi) * (Scale_max ? scale : float(M_LOG2E)) - max_offset);
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) { for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
...@@ -125,8 +129,11 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso ...@@ -125,8 +129,11 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows> template <int kNRows, bool Use_max_offset_ = false>
struct Softmax { struct Softmax {
constexpr static bool Use_max_offset = Use_max_offset_;
// constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
// constexpr static float max_offset_E = max_offset * float(M_LN2);
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{})); using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum; TensorT row_max, row_sum;
...@@ -166,7 +173,7 @@ struct Softmax { ...@@ -166,7 +173,7 @@ struct Softmax {
TensorT scores_scale; TensorT scores_scale;
if constexpr (Is_first) { if constexpr (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max); flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); flash::template scale_apply_exp2</*Scale_max=*/true, /*Check_inf=*/true, Use_max_offset>(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum); flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f); cute::fill(scores_scale, 1.f);
// if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); } // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); }
...@@ -183,16 +190,17 @@ struct Softmax { ...@@ -183,16 +190,17 @@ struct Softmax {
// scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); // scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
// row_sum(mi) *= scores_scale(mi); // row_sum(mi) *= scores_scale(mi);
// } // }
flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf>(scores, row_max, softmax_scale_log2); flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Use_max_offset>(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum. // We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax. // We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum); flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
} }
return scores_scale; return scores_scale;
}; };
template<bool Is_dropout=false, bool Split=false, typename Tensor0> template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) { __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) {
constexpr static float max_offset_E = Use_max_offset ? 8.0f * float(M_LN2) : 0.0f;
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows); static_assert(decltype(size<0>(scores))::value == kNRows);
...@@ -203,7 +211,7 @@ struct Softmax { ...@@ -203,7 +211,7 @@ struct Softmax {
for (int mi = 0; mi < size(row_max); ++mi) { for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi); float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum);
scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout; scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
} }
return scores_scale; return scores_scale;
......
...@@ -24,9 +24,9 @@ def print_diffs(out, out_ref): ...@@ -24,9 +24,9 @@ def print_diffs(out, out_ref):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
...@@ -38,6 +38,7 @@ def print_diffs(out, out_ref): ...@@ -38,6 +38,7 @@ def print_diffs(out, out_ref):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seqlen_q,seqlen_k", "seqlen_q,seqlen_k",
[ [
(1, 1),
(257, 1), (257, 1),
(64, 128), (64, 128),
(128, 128), (128, 128),
...@@ -53,28 +54,43 @@ def print_diffs(out, out_ref): ...@@ -53,28 +54,43 @@ def print_diffs(out, out_ref):
(1024, 1024), (1024, 1024),
(1023, 1024), (1023, 1024),
(1024, 1023), (1024, 1023),
(2048, 2048), (4096, 4096),
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output( def test_flash_attn_output(
seqlen_q, seqlen_k, d, causal, mha_type, dtype seqlen_q, seqlen_k, d, causal, mha_type, dtype,
): ):
device = "cuda" device = "cuda"
if(dtype == torch.float8_e4m3fn):
dtype_init = torch.float16
else:
dtype_init = dtype
print(dtype)
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
# batch_size = 40 # batch_size = 40
# nheads = 16 # nheads = 16
batch_size = 9 batch_size = 4
nheads = 6 nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
# nheads_kv = 2 # nheads_kv = 2
# batch_size = 9 # batch_size = 9
# nheads = 6 # nheads = 6
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
out, lse = flash_attn_func(q, k, v, causal=causal) out, lse = flash_attn_func(q, k, v, causal=causal)
q = q.to(dtype_init)
k = k.to(dtype_init)
v = v.to(dtype_init)
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, q,
k, k,
...@@ -105,8 +121,9 @@ def test_flash_attn_output( ...@@ -105,8 +121,9 @@ def test_flash_attn_output(
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# if not causal: # if not causal:
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# breakpoint() # breakpoint()
# if d <= 128: # if d <= 128:
...@@ -139,7 +156,11 @@ def test_flash_attn_output( ...@@ -139,7 +156,11 @@ def test_flash_attn_output(
# Check that FlashAttention's numerical error is at most twice the numerical error # Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation. # of a Pytorch implementation.
# breakpoint() # breakpoint()
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if(dtype != torch.float8_e4m3fn):
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
else:
# just test correctness of fp8 kernel w/o further quantization techniques
assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
# if d <= 128: # if d <= 128:
# assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() # assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
......
...@@ -164,7 +164,7 @@ public: ...@@ -164,7 +164,7 @@ public:
}; };
template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup> template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads = cutlass::NumThreadsPerWarp>
class DynamicPersistentTileScheduler { class DynamicPersistentTileScheduler {
protected: protected:
...@@ -228,13 +228,13 @@ public: ...@@ -228,13 +228,13 @@ public:
CUTLASS_DEVICE CUTLASS_DEVICE
void void
init_consumer() const { init_consumer() const {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
} }
CUTLASS_DEVICE CUTLASS_DEVICE
void void
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { if (threadIdx.x % NumProducerThreads == 0) {
current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
} }
} }
...@@ -242,24 +242,28 @@ public: ...@@ -242,24 +242,28 @@ public:
CUTLASS_DEVICE CUTLASS_DEVICE
void void
broadcast_next_work(WorkTileInfo& current_work) const { broadcast_next_work(WorkTileInfo& current_work) const {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { if (threadIdx.x % NumProducerThreads == 0) {
*tile_count_smem = current_work.tile_idx; *tile_count_smem = current_work.tile_idx;
} }
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
} }
template<bool IsProducer=false> template<bool IsProducer=false>
CUTLASS_DEVICE CUTLASS_DEVICE
WorkTileInfo WorkTileInfo
get_next_work(Params const& params, WorkTileInfo const& current_work) const { get_next_work(Params const& params, WorkTileInfo const& current_work) const {
if constexpr (IsProducer) { if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) {
// thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 // thread 0 already has the right tile_idx, just need to broadcast to the rest of the producer threads (warp 0)
return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)}; return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
} else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) {
// TODO: investigate optimal synchronize
int tile_idx = *tile_count_smem;
return {tile_idx};
} else { } else {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/); cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
int tile_idx = *tile_count_smem; int tile_idx = *tile_count_smem;
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
return {tile_idx}; return {tile_idx};
} }
} }
......
...@@ -143,6 +143,38 @@ __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { ...@@ -143,6 +143,38 @@ __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
} }
}; };
// Convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _4>{}); // (2, 2, (2, N / 32)))
return make_layout(make_layout(Shape<_4, _2, _2>{}),
get<1>(acc_layout),
make_layout(get<2, 1>(l), get<2>(acc_layout)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Byte permute for fp8 kernel
template <typename Fragment>
CUTLASS_DEVICE void permute_regs_A_to_C(Fragment &accum) {
auto data = accum.data();
#pragma unroll
for (int n = 0; n < size(accum); n += 8) {
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
auto upper = data_32bit[0];
auto lower = data_32bit[1];
data_32bit[0] = __byte_perm(upper, lower, 0x5410);
data_32bit[1] = __byte_perm(upper, lower, 0x7632);
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout> template <typename To_type, typename Engine, typename Layout>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment