Unverified Commit fd249aac authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Add Sparse Decoding Kernel and Sparse Prefill Kernel for Blackwell


Signed-off-by: default avatarsimon-mo <simon.mo@hey.com>
parent 17944550
...@@ -33,6 +33,8 @@ python tests/test_flash_mla_decoding.py ...@@ -33,6 +33,8 @@ python tests/test_flash_mla_decoding.py
The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8.
For Blackwell GPUs, the token-level sparse MLA decoding kernel can achieve up to 350 TFlops on B200; note that this kernel has not been fully optimized yet.
#### Test & benchmark MHA prefill (Dense): #### Test & benchmark MHA prefill (Dense):
```bash ```bash
...@@ -47,7 +49,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation ...@@ -47,7 +49,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation
python tests/test_flash_mla_prefill.py python tests/test_flash_mla_prefill.py
``` ```
It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8. It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9.
## Requirements ## Requirements
...@@ -60,9 +62,9 @@ Support matrix: ...@@ -60,9 +62,9 @@ Support matrix:
| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: |
| Dense Decoding | Hopper | MQA | BF16 | | Dense Decoding | Hopper | MQA | BF16 |
| Sparse Decoding | Hopper | MQA | FP8 [1] | | Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] |
| Dense Prefill | Blackwell | MHA | | | Dense Prefill | Blackwell | MHA | |
| Sparse Prefill | Hopper | MQA | | | Sparse Prefill | Hopper & Blackwell | MQA | |
[1]: For more details on using FP8 KV cache, see documents below. [1]: For more details on using FP8 KV cache, see documents below.
......
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#include "sm90/decode/dense/splitkv_mla.h" #include "sm90/decode/dense/splitkv_mla.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm90/prefill/sparse/fwd.h" #include "sm90/prefill/sparse/fwd.h"
#include "sm100/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/prefill/dense/interface.h" #include "sm100/prefill/dense/interface.h"
#include "sm100/prefill/sparse/fwd.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
...@@ -31,7 +33,7 @@ struct Arch { ...@@ -31,7 +33,7 @@ struct Arch {
} }
bool is_sm100() const { bool is_sm100() const {
return major == 10 && minor == 0; return major == 10;
} }
void assert_is_supported() const { void assert_is_supported() const {
...@@ -86,7 +88,31 @@ DecodingAttnImplMeta get_attn_impl_meta( ...@@ -86,7 +88,31 @@ DecodingAttnImplMeta get_attn_impl_meta(
} }
} }
} else if (arch.is_sm100()) { } else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture"); if (is_sparse_attn) {
if (is_fp8_kvcache) {
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0);
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// FP8 + Sparse MLA
return {
std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1),
5,
64
};
} else {
// Sparse BF16 MLA
TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100");
}
} else {
if (is_fp8_kvcache) {
// FP8 MLA
TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100");
} else {
// Normal BF16 MLA
TORCH_CHECK(false, "BF16 Dence MLA is not supported on SM100");
}
}
} else { } else {
TORCH_CHECK(false, "Unsupported GPU architecture"); TORCH_CHECK(false, "Unsupported GPU architecture");
} }
...@@ -326,7 +352,8 @@ fwd_kvcache_mla( ...@@ -326,7 +352,8 @@ fwd_kvcache_mla(
} }
} }
} else if (arch.is_sm100()) { } else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture"); TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100");
sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else { } else {
TORCH_CHECK(false, "Unsupported GPU architecture"); TORCH_CHECK(false, "Unsupported GPU architecture");
} }
...@@ -366,7 +393,8 @@ std::vector<at::Tensor> sparse_prefill_fwd( ...@@ -366,7 +393,8 @@ std::vector<at::Tensor> sparse_prefill_fwd(
) { ) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9; bool is_sm90 = dprops->major == 9;
TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures"); bool is_sm100 = dprops->major == 10;
TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures");
CHECK_DEVICE(q); CHECK_DEVICE(q);
CHECK_DEVICE(kv); CHECK_DEVICE(kv);
...@@ -423,6 +451,8 @@ std::vector<at::Tensor> sparse_prefill_fwd( ...@@ -423,6 +451,8 @@ std::vector<at::Tensor> sparse_prefill_fwd(
if (is_sm90) { if (is_sm90) {
sm90::run_fwd_kernel(params); sm90::run_fwd_kernel(params);
} else if (is_sm100) {
sm100::run_fwd_kernel(params);
} else { } else {
TORCH_CHECK(false, "Unknown architecture"); TORCH_CHECK(false, "Unknown architecture");
} }
......
#pragma once
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include "sm100/defines.h"
namespace sm100 {
// Eight FP8 (e4m3) values stored as two 4-wide CUDA FP8 vectors.
// cvt_fp8x8_bf16x8 converts `lo` into output elements 0-3 and `hi` into 4-7.
struct fp8x8 {
__nv_fp8x4_e4m3 lo;
__nv_fp8x4_e4m3 hi;
};
// Thirty-two FP8 values grouped as four consecutive fp8x8 chunks.
// This is the payload type produced by ldg_256_fp8x32.
struct fp8x32 {
    fp8x8 a0;
    fp8x8 a1;
    fp8x8 a2;
    fp8x8 a3;
};
// Sixteen FP8 values grouped as two consecutive fp8x8 chunks.
// This is the payload type produced by ldg_128_fp8x16.
struct fp8x16 {
    fp8x8 a0;
    fp8x8 a1;
};
// Dequantizes eight FP8 (e4m3) values to bfloat16 and applies a scale factor.
//
// inputs: the eight FP8 values (`lo` holds elements 0-3, `hi` holds elements 4-7).
// scale:  FP32 scale; it is rounded to bfloat16 first and the multiplication is
//         carried out on packed bfloat162 values.
// Returns the eight scaled bfloat16 values as four packed pairs.
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
    // Scale converted to a packed bfloat162 so that it can multiply two lanes at once.
    __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);

    // Converts one fp8x4 vector to float4, then to two bfloat162 pairs, scaling each pair.
    #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
    { \
        float4 fp32x4 = (float4)(FP8x4); \
        OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
        OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
    }

    bf16x8 result;
    DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
    DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);

    // The helper macro is private to this function; do not leak it to includers of this header.
    #undef DEQUANT_FP8x4

    return result;
}
// Loads 32 FP8 values (eight 32-bit words) from global memory with one vectorized PTX load.
// The instruction is `ld.global.nc` with `evict_normal` hints for both L1 and L2 and the
// `L2::256B` qualifier; the eight destination registers are the fields of an int32x8_t.
// NOTE(review): a v8.s32 load presumably requires `src_ptr` to be 32-byte aligned, and the
// `.nc` path presumably requires the data to be read-only while the kernel runs — confirm
// at the call sites.
__device__ __forceinline__
fp8x32 ldg_256_fp8x32(void* src_ptr) {
int32x8_t val;
asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
: "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3),
"=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7)
: "l"(src_ptr)
);
// Reinterpret the eight raw 32-bit words as an fp8x32 (four fp8x8 groups).
return *reinterpret_cast<fp8x32*>(&val);
}
// Loads 16 FP8 values (four 32-bit words) from global memory with one vectorized PTX load.
// The instruction is `ld.global.nc` with an `L1::evict_first` hint; the four destination
// registers are the components of an int4.
// NOTE(review): a v4.s32 load presumably requires `src_ptr` to be 16-byte aligned, and the
// `.nc` path presumably requires the data to be read-only while the kernel runs — confirm
// at the call sites.
__device__ __forceinline__
fp8x16 ldg_128_fp8x16(void* src_ptr) {
int4 ret;
asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
: "l"(src_ptr));
// Reinterpret the four raw 32-bit words as an fp8x16 (two fp8x8 groups).
return *reinterpret_cast<fp8x16*>(&ret);
}
}
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm100 {
// Entry point of the SM100 FP8 sparse split-KV MLA decoding kernel (declaration only; the
// definition is not part of this header).
// params: decoding parameters, passed by non-const reference.
// stream: CUDA stream argument forwarded by the caller.
// NOTE(review): whether `params` is modified by the callee cannot be determined from this
// declaration — confirm against the definition.
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
namespace sm100 {
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
// Plain aggregate of eight 32-bit integers, declared in order a0..a7.
// Intended as a register bundle for 8-wide vector loads.
struct int32x8_t {
    int a0;
    int a1;
    int a2;
    int a3;
    int a4;
    int a5;
    int a6;
    int a7;
};
// Eight FP32 values stored as four float2 pairs; the field name gives the two
// element indices held by each pair.
struct float8 {
    float2 a01;
    float2 a23;
    float2 a45;
    float2 a67;
};
// Eight bfloat16 values stored as four packed __nv_bfloat162 pairs; the field name gives
// the two element indices held by each pair.
struct bf16x8 {
__nv_bfloat162 a01;
__nv_bfloat162 a23;
__nv_bfloat162 a45;
__nv_bfloat162 a67;
};
}
#pragma once
#include <cute/tensor.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
using _72 = Int<72>;
using _576 = Int<576>;
// Issues a TMA copy from `src` to `dst` whose completion is reported on `bar`.
//
// tma_copy:   the TMA copy object; `.with(...)` binds the barrier, a multicast mask of 0
//             and the cache hint to produce the executable copy.
// src / dst:  tensors that are partitioned with slice 0 of `tma_copy`.
// bar:        transaction barrier, passed to TMA as its underlying ValueType storage.
// cache_hint: SM90-style cache hint, EVICT_NORMAL by default.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using BarrierStorage = typename transac_bar_t::ValueType;

    auto tma_slice = tma_copy.get_slice(_0{});
    BarrierStorage &raw_bar = reinterpret_cast<BarrierStorage&>(bar);

    cute::copy(
        tma_copy.with(raw_bar, 0, cache_hint),
        tma_slice.partition_S(src),
        tma_slice.partition_D(dst)
    );
}
// Issues one tiled MMA over every K-block, with A and B both partitioned from the given
// tensors (`sA`, `sB`) and the result accumulated into `tC_frag`.
//
// clear_accum: when true the first K-block overwrites the accumulator (ScaleOut::Zero);
//              every later K-block always accumulates (ScaleOut::One).
// Side effect: leaves `tiled_mma.accumulate_` modified (ScaleOut::One once at least one
//              K-block has been issued).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ss(
    TiledMMA &tiled_mma,
    TensorA sA,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    // A/B/C are already CTA-local tiles, so the slice index passed here does not matter.
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_a = cta_mma.partition_fragment_A(sA);
    auto frag_b = cta_mma.partition_fragment_B(sB);

    static_assert(size<2>(frag_a) == size<2>(frag_b));
    static_assert(size<1>(frag_a) == size<1>(tC_frag));
    static_assert(size<1>(frag_b) == size<2>(tC_frag));

    CUTE_UNROLL
    for (int k_block = 0; k_block < size<2>(frag_a); ++k_block) {
        cute::gemm(tiled_mma, frag_a(_, _, k_block), frag_b(_, _, k_block), tC_frag);
        // Only the first K-block may clear; all following ones accumulate.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Issues one tiled MMA over every K-block where operand A is supplied as an already
// partitioned fragment (`tA_frag`) and operand B is partitioned from `sB`.
//
// clear_accum: when true the first K-block overwrites the accumulator (ScaleOut::Zero);
//              every later K-block always accumulates (ScaleOut::One).
// Side effect: leaves `tiled_mma.accumulate_` modified (ScaleOut::One once at least one
//              K-block has been issued).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ts(
    TiledMMA &tiled_mma,
    TensorA tA_frag,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    // A/B/C are already CTA-local tiles, so the slice index passed here does not matter.
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_b = cta_mma.partition_fragment_B(sB);

    static_assert(size<2>(tA_frag) == size<2>(frag_b));

    CUTE_UNROLL
    for (int k_block = 0; k_block < size<2>(tA_frag); ++k_block) {
        cute::gemm(tiled_mma, tA_frag(_, _, k_block), frag_b(_, _, k_block), tC_frag);
        // Only the first K-block may clear; all following ones accumulate.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
}
This diff is collapsed.
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm100 {
// Entry point of the SM100 sparse prefill forward kernel (declaration only; the definition
// is not part of this header).
// params: sparse prefill parameters, taken by const reference.
void run_fwd_kernel(const SparsePrefillParams& params);
}
#pragma once
#include <cute/tensor.hpp>
#include "sm100/defines.h"
namespace sm100 {
using namespace cute;
using _72 = Int<72>;
using _576 = Int<576>;
// Issues a TMA copy from `src` to `dst` whose completion is reported on `bar`.
//
// tma_copy:   the TMA copy object; `.with(...)` binds the barrier, a multicast mask of 0
//             and the cache hint to produce the executable copy.
// src / dst:  tensors that are partitioned with slice 0 of `tma_copy`.
// bar:        transaction barrier, passed to TMA as its underlying ValueType storage.
// cache_hint: SM90-style cache hint, EVICT_NORMAL by default.
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using BarrierStorage = typename transac_bar_t::ValueType;

    auto tma_slice = tma_copy.get_slice(_0{});
    BarrierStorage &raw_bar = reinterpret_cast<BarrierStorage&>(bar);

    cute::copy(
        tma_copy.with(raw_bar, 0, cache_hint),
        tma_slice.partition_S(src),
        tma_slice.partition_D(dst)
    );
}
// Issues one tiled MMA over every K-block, with A and B both partitioned from the given
// tensors (`sA`, `sB`) and the result accumulated into `tC_frag`.
//
// clear_accum: when true the first K-block overwrites the accumulator (ScaleOut::Zero);
//              every later K-block always accumulates (ScaleOut::One).
// Side effect: leaves `tiled_mma.accumulate_` modified (ScaleOut::One once at least one
//              K-block has been issued).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma(
    TiledMMA &tiled_mma,
    TensorA sA,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    // A/B/C are already CTA-local tiles, so the slice index passed here does not matter.
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_a = cta_mma.partition_fragment_A(sA);
    auto frag_b = cta_mma.partition_fragment_B(sB);

    static_assert(size<2>(frag_a) == size<2>(frag_b));
    static_assert(size<1>(frag_a) == size<1>(tC_frag));
    static_assert(size<1>(frag_b) == size<2>(tC_frag));

    CUTE_UNROLL
    for (int k_block = 0; k_block < size<2>(frag_a); ++k_block) {
        cute::gemm(tiled_mma, frag_a(_, _, k_block), frag_b(_, _, k_block), tC_frag);
        // Only the first K-block may clear; all following ones accumulate.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Issues one tiled MMA over every K-block where operand A is supplied as an already
// partitioned fragment (`tA_frag`) and operand B is partitioned from `sB`.
//
// clear_accum: when true the first K-block overwrites the accumulator (ScaleOut::Zero);
//              every later K-block always accumulates (ScaleOut::One).
// Side effect: leaves `tiled_mma.accumulate_` modified (ScaleOut::One once at least one
//              K-block has been issued).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ts(
    TiledMMA &tiled_mma,
    TensorA tA_frag,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    // A/B/C are already CTA-local tiles, so the slice index passed here does not matter.
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_b = cta_mma.partition_fragment_B(sB);

    static_assert(size<2>(tA_frag) == size<2>(frag_b));

    CUTE_UNROLL
    for (int k_block = 0; k_block < size<2>(tA_frag); ++k_block) {
        cute::gemm(tiled_mma, tA_frag(_, _, k_block), frag_b(_, _, k_block), tC_frag);
        // Only the first K-block may clear; all following ones accumulate.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Eight bfloat16 values stored as four packed __nv_bfloat162 pairs; the field name gives
// the two element indices held by each pair.
// NOTE(review): this header includes "sm100/defines.h"; if that header also defines
// sm100::bf16x8, this is a redefinition in the same namespace — confirm and keep only one.
struct bf16x8 {
__nv_bfloat162 a01;
__nv_bfloat162 a23;
__nv_bfloat162 a45;
__nv_bfloat162 a67;
};
}
This diff is collapsed.
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe does not provide a UTCMMA op for the `.ws` form of tcgen05.mma, so it is added here.
//
// MMA op for `tcgen05.mma.ws.cta_group::1.kind::f16` with A and B given as 64-bit
// descriptors and the accumulator addressed in tensor memory.
// NOTE(review): per the NOELECT suffix this op presumably performs no elect_one_sync()
// itself, so the caller must make sure a single thread issues it — confirm at call sites.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The message now states the values that the condition actually accepts.
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues the MMA instruction.
  //   desc_a / desc_b: 64-bit descriptors of the A and B operands.
  //   tmem_c:          32-bit tensor-memory address of the accumulator.
  //   scaleC:          non-zero sets the predicate operand of the instruction, zero clears it
  //                    (the traits pass the ScaleOut accumulate flag here).
  //   idescE:          instruction descriptor; only its upper 32 bits are passed on.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // Guarded like the other *_NOELECT ops in this file so that builds without tcgen05
    // support hit an explicit invalid-control-path instead of emitting unsupported PTX.
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
    // NOTE(review): the trailing literal 0 is presumably the optional zero-column-mask
    // operand of the .ws form — confirm against the PTX ISA.
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_WS_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
  }
};
// MMA_Traits specialization that describes SM100_MMA_F16BF16_WS_SS_NOELECT to CuTe:
// value types, fragment types, the MNK shape, the thread/value layouts, and the
// mma_unpack hook that forwards tensors to the op's fma().
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");
// A and B are shared-memory descriptors; C uses the `.ws` single-SM tensor-memory fragment.
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256bits, transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// Single participant (ThrID has one entry); its stride in every layout below is 0.
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
// Instruction descriptor built at compile time from the template parameters.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Unpacks the operand tensors and issues the instruction.
// A and B must be descriptor registers, D and C must live in tensor memory.
// Only D's address is passed to fma(); C is not read here, and whether the existing
// accumulator contents are used is controlled by `accumulate_`.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
// Runtime form of the instruction descriptor; fma() consumes its upper 32 bits.
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// MMA op for `tcgen05.mma.cta_group::2.kind::f16` where operand A is addressed in tensor
// memory and operand B is given as a 64-bit descriptor.
// NOTE(review): per the NOELECT suffix this op presumably performs no elect_one_sync()
// itself, so the caller must make sure a single thread issues it — confirm at call sites.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");
using DRegisters = void;
using ARegisters = uint32_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
// Issues the MMA instruction.
//   tmem_a: 32-bit tensor-memory address of operand A.
//   desc_b: 64-bit descriptor of operand B.
//   tmem_c: 32-bit tensor-memory address of the accumulator.
//   scaleC: non-zero sets the predicate operand of the instruction, zero clears it.
//   idescE: instruction descriptor; only its upper 32 bits are passed on.
CUTE_HOST_DEVICE static void
fma(uint32_t const& tmem_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
// All-zero 8-register vector operand.
// NOTE(review): presumably the disable-output-lane mask, i.e. no lanes disabled — confirm
// against the PTX ISA.
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// MMA_Traits specialization that describes SM100_MMA_F16BF16_2x1SM_TS_NOELECT to CuTe:
// value types, fragment types, the MNK shape, the two-participant thread/value layouts,
// and the mma_unpack hook that forwards tensors to the op's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
          UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, c_sat>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");

  // A lives in tensor memory (Duplicated alloc mode), B is a shared-memory descriptor,
  // C uses the two-SM tensor-memory fragment.
  using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // Size of instructions' K extent is always 256 bits; convert to units of element
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  // Two participants; each owns half of the M (A, C) or N (B) extent in the layouts below.
  using ThrID   = Layout<_2>;
  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Instruction descriptor built at compile time from the template parameters.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();

  // Unpacks the operand tensors and issues the instruction.
  // A, D and C must live in tensor memory; B must be a descriptor register.
  // Only D's address is passed to fma(); C is not read here, and whether the existing
  // accumulator contents are used is controlled by `accumulate_`.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    // A is a tensor-memory operand here, so the diagnostic must say "tmem".
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    // fma() takes the A address as uint32_t, like the accumulator address below.
    uint32_t tmem_a = raw_pointer_cast(A.data());
    uint64_t desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    // Runtime form of the instruction descriptor; fma() consumes its upper 32 bits.
    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                       M, N,
                                       a_major, b_major,
                                       a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
  }
};
// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync()
// MMA op for `tcgen05.mma.cta_group::2.kind::f16` with A and B given as 64-bit descriptors
// and the accumulator addressed in tensor memory.
// NOTE(review): since no election happens here, the caller presumably has to make sure a
// single thread issues the instruction — confirm at call sites.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
// Issues the MMA instruction.
//   desc_a / desc_b: 64-bit descriptors of the A and B operands.
//   tmem_c:          32-bit tensor-memory address of the accumulator.
//   scaleC:          non-zero sets the predicate operand of the instruction, zero clears it.
//   idescE:          instruction descriptor; only its upper 32 bits are passed on.
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
// All-zero 8-register vector operand.
// NOTE(review): presumably the disable-output-lane mask, i.e. no lanes disabled — confirm
// against the PTX ISA.
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// template <class a_type, class b_type, class c_type,
// int M, int N, UMMA::Major a_major, UMMA::Major b_major,
// UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
// struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> : MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> {};
// MMA_Traits specialization that describes SM100_MMA_F16BF16_2x1SM_SS_NOELECT to CuTe:
// value types, fragment types, the MNK shape, the two-participant thread/value layouts,
// and the mma_unpack hook that forwards tensors to the op's fma().
template <class a_type, class b_type, class c_type,
int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");
// A and B are shared-memory descriptors; C uses the two-SM tensor-memory fragment.
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of the instruction's K extent is always 256 bits; convert to units of elements
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
// Two participants; each owns half of the M (A, C) or N (B) extent in the layouts below.
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
// Instruction descriptor built at compile time from the template parameters.
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
// Unpacks the operand tensors and issues the instruction.
// A and B must be descriptor registers, D and C must live in tensor memory.
// Only D's address is passed to fma(); C is not read here, and whether the existing
// accumulator contents are used is controlled by `accumulate_`.
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
// Runtime form of the instruction descriptor; fma() consumes its upper 32 bits.
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
}
\ No newline at end of file
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe's built-in SM100_TMA_2SM_LOAD_1D family requires two participating threads
// (using ThrID = Layout<_2>;) and additionally splits the data, which makes it very awkward
// to use, so we provide our own variant here. Also, to stay uniform with the other code that
// uses SM90 TMA, this variant accepts TMA::CacheHintSm90 instead of TMA::CacheHintSm100.
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
// 1-D variant: issues `cp.async.bulk.tensor.1d.cta_group::2` with an L2 cache hint.
//   desc_ptr:   pointer to the TMA descriptor (passed to PTX as a 64-bit integer).
//   mbar_ptr:   shared-memory mbarrier that receives the complete_tx byte count.
//   cache_hint: 64-bit L2 cache-hint value.
//   smem_ptr:   shared-memory destination.
//   crd0:       tensor coordinate.
struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
[[maybe_unused]] void * smem_ptr,
[[maybe_unused]] int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
// 2-D TMA load: issues `cp.async.bulk.tensor.2d.cta_group::2` from global to shared memory
// with an L2 cache hint.
//   desc_ptr:   pointer to the TMA descriptor (passed to PTX as a 64-bit integer).
//   mbar_ptr:   shared-memory mbarrier that receives the complete_tx byte count.
//   cache_hint: 64-bit L2 cache-hint value.
//   smem_ptr:   shared-memory destination.
//   crd0, crd1: tensor coordinates.
struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
{
  // Every parameter is [[maybe_unused]] because none of them is referenced when
  // CUTE_ARCH_TMA_SM100_ENABLED is not defined.
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    // Executed by both CTAs. Set peer bit to 0 so that the
    // transaction bytes will update CTA0's barrier.
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4}], [%2], %5;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 3-D TMA load: issues `cp.async.bulk.tensor.3d.cta_group::2` from global to shared memory
// with an L2 cache hint.
//   desc_ptr:   pointer to the TMA descriptor (passed to PTX as a 64-bit integer).
//   mbar_ptr:   shared-memory mbarrier that receives the complete_tx byte count.
//   cache_hint: 64-bit L2 cache-hint value.
//   smem_ptr:   shared-memory destination.
//   crd0..crd2: tensor coordinates.
struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
{
  // Every parameter is [[maybe_unused]] because none of them is referenced when
  // CUTE_ARCH_TMA_SM100_ENABLED is not defined.
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1, [[maybe_unused]] int32_t const& crd2)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    // Executed by both CTAs. Set peer bit to 0 so that the
    // transaction bytes will update CTA0's barrier.
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5}], [%2], %6;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 4-D TMA load: issues `cp.async.bulk.tensor.4d.cta_group::2` from global to shared memory
// with an L2 cache hint.
//   desc_ptr:   pointer to the TMA descriptor (passed to PTX as a 64-bit integer).
//   mbar_ptr:   shared-memory mbarrier that receives the complete_tx byte count.
//   cache_hint: 64-bit L2 cache-hint value.
//   smem_ptr:   shared-memory destination.
//   crd0..crd3: tensor coordinates.
struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
{
  // Parameters are [[maybe_unused]], matching the 1-D variant: none of them is referenced
  // when CUTE_ARCH_TMA_SM100_ENABLED is not defined.
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1,
       [[maybe_unused]] int32_t const& crd2, [[maybe_unused]] int32_t const& crd3)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    // Executed by both CTAs. Set peer bit to 0 so that the
    // transaction bytes will update CTA0's barrier.
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 5-D TMA load: issues `cp.async.bulk.tensor.5d.cta_group::2` from global to shared memory
// with an L2 cache hint.
//   desc_ptr:   pointer to the TMA descriptor (passed to PTX as a 64-bit integer).
//   mbar_ptr:   shared-memory mbarrier that receives the complete_tx byte count.
//   cache_hint: 64-bit L2 cache-hint value.
//   smem_ptr:   shared-memory destination.
//   crd0..crd4: tensor coordinates.
struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
{
  // Parameters are [[maybe_unused]], matching the 1-D variant: none of them is referenced
  // when CUTE_ARCH_TMA_SM100_ENABLED is not defined.
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, [[maybe_unused]] int32_t const& crd1,
       [[maybe_unused]] int32_t const& crd2, [[maybe_unused]] int32_t const& crd3,
       [[maybe_unused]] int32_t const& crd4)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
    // Executed by both CTAs. Set peer bit to 0 so that the
    // transaction bytes will update CTA0's barrier.
    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile (
      "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// Rank-dispatching front end for the *_NOSPLIT TMA loads: each copy() overload forwards
// its arguments unchanged to the 1-D .. 5-D op that matches the number of coordinates.
struct SM100_TMA_2SM_LOAD_NOSPLIT
{
  // One coordinate -> 1-D load.
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0)
  {
    SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);
  }

  // Two coordinates -> 2-D load.
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
    SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);
  }

  // Three coordinates -> 3-D load.
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
    SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);
  }

  // Four coordinates -> 4-D load.
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
    SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
  }

  // Five coordinates -> 5-D load.
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
    SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
  }

  // Prefetching reuses the SM90 TMA load's PREFETCH op.
  using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
};
// Distinct tag type that inherits every copy() overload of SM100_TMA_2SM_LOAD_NOSPLIT.
// The Copy_Traits `with(...)` functions return traits keyed on this type, i.e. it names the
// executable form of the load.
struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {};
// Non-executable Copy_Traits for SM100_TMA_2SM_LOAD_NOSPLIT.
// It stores the TMA descriptor and auxiliary parameters but no mbarrier, so it cannot be
// used to copy directly (copy_unpack is deleted). Call .with(tma_mbar) to obtain the
// executable Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, ...>.
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
{
  using ThrID = Layout<_1>;
  // (src-thr,src-val) -> bit
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // (dst-thr,dst-val) -> bit
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // Reference (thr,val) -> bit; identical to the source map
  using RefLayout = SrcLayout;

  // Stored state: the TMA descriptor followed by the auxiliary parameters
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams aux_params_;

  // Address of the descriptor held by this object
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  // Build the executable traits from the stored descriptor and the given mbarrier.
  // multicast_mask is accepted only so both atoms expose the same API; it is never
  // forwarded to the resulting traits.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
  with(
    uint64_t& tma_mbar,
    [[maybe_unused]] uint16_t const& multicast_mask = 0,
    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
    // Same result as the overload below, with the stored descriptor as new_tma_desc
    return with(&tma_desc_, tma_mbar, multicast_mask, cache_hint);
  }

  // Build the executable traits from a caller-supplied descriptor instead of the stored one
  // (temp. overload for grouped gemm / ptr array gemm).
  // multicast_mask is accepted only so both atoms expose the same API.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
  with(
    TmaDescriptor const* new_tma_desc,
    uint64_t& tma_mbar,
    [[maybe_unused]] uint16_t const& multicast_mask = 0,
    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
    uint64_t const hint_bits = static_cast<uint64_t>(cache_hint);
    return {new_tma_desc, &tma_mbar, hint_bits};
  }

  // Coordinate tensor for a global tensor of shape g_shape, laid out with the stored
  // g_stride_. The shape must be congruent with that stride.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    using GStride = decltype(aux_params_.g_stride_);
    static_assert(is_congruent<decltype(g_shape), GStride>::value);
    auto coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // Copying through the non-executable traits is a compile-time error: call .with() first
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst) = delete;
};
// Executable Copy_Traits for SM100_TMA_2SM_LOAD_NOSPLIT: holds the descriptor pointer,
// the mbarrier pointer and the cache hint, and derives from TMA_LOAD_Unpack.
// Instances are produced by .with() on the non-executable traits.
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
     : TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
{
  using ThrID = Layout<_1>;
  // (src-thr,src-val) -> bit
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // (dst-thr,dst-val) -> bit
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // Reference (thr,val) -> bit; identical to the source map
  using RefLayout = SrcLayout;

  // Leading arguments of SM100_TMA_2SM_LOAD_NOSPLIT::copy, in call order
  tuple<
  TmaDescriptor const*,
  uint64_t*, // smem mbarrier
  uint64_t   // cache hint
  > const opargs_;

  // Packs the three operation arguments into opargs_
  CUTE_HOST_DEVICE
  Copy_Traits(TmaDescriptor const* tma_desc, uint64_t* tma_mbar, uint64_t cache_hint)
    : opargs_(tma_desc, tma_mbar, cache_hint) {}
};
}
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe does not provide UTCMMA atoms with the .ws qualifier, so we add them here
// tcgen05.mma.ws (cta_group::1, kind::f16) with both A and B supplied as 64-bit descriptors.
// The instruction is emitted unconditionally; the NOELECT suffix presumably means the
// caller is responsible for making sure a single thread issues it — TODO confirm against callers.
// NOTE(review): unlike the 2x1SM ops in this file, fma() is not guarded by
// CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The message now lists the values the condition actually accepts
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issue one MMA.
  //   desc_a, desc_b : 64-bit descriptors for A and B
  //   tmem_c         : 32-bit address operand of the accumulator (the bracketed [%0])
  //   scaleC         : any non-zero value sets the predicate operand p
  //   idescE         : only the upper 32 bits are passed to the instruction
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // The trailing literal 0 is presumably the optional zero-column-mask descriptor —
    // confirm against the PTX ISA.
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
  }
};
// MMA_Traits for SM100_MMA_F16BF16_WS_SS_NOELECT: A and B are shared-memory descriptors
// (UMMA::smem_desc), the accumulator uses the UMMA::tmem_frg_ws_1sm fragment type.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                                  M, N, a_major, b_major,
                                                  a_neg, b_neg>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");

  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;

  // One instruction spans 256 bits along K; expressed in elements of A
  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_1>;
  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
                         Stride<_0,Stride<    _1,Int<M>>>>;
  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
                         Stride<_0,Stride<    _1,Int<N>>>>;
  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
                         Stride<_0,Stride<    _1,Int<M>>>>;

  // Compile-time instruction descriptor for this atom
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // 1: accumulate onto C, 0: ignore C (clears the accumulators)
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Pulls the raw operands out of the fragments and issues a single MMA.
  // Only D's address is used for the accumulator; C is checked for its memory space only.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using MmaOp = SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                                  M, N, a_major, b_major,
                                                  a_neg, b_neg>;

    uint64_t const a_desc        = A[0];
    uint64_t const b_desc        = B[0];
    uint32_t const acc_addr      = raw_pointer_cast(D.data());
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
    uint32_t const scale_c       = static_cast<uint32_t>(traits.accumulate_);

    MmaOp::fma(a_desc, b_desc, acc_addr, scale_c, runtime_idesc);
  }
};
// tcgen05.mma.ws (cta_group::1, kind::f16) with A given as a 32-bit bracketed address operand
// and B as a 64-bit descriptor.
// The instruction is emitted unconditionally; the NOELECT suffix presumably means the
// caller is responsible for making sure a single thread issues it — TODO confirm against callers.
// NOTE(review): unlike the 2x1SM ops in this file, fma() is not guarded by
// CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_TS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The message now lists the values the condition actually accepts
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  // NOTE(review): fma() takes A as a 32-bit value and SM100_MMA_F16BF16_2x1SM_TS_NOELECT
  // declares ARegisters as uint32_t[1]; uint64_t[1] here looks inherited from the SS
  // variant — confirm before changing.
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issue one MMA.
  //   tmem_a : 32-bit address operand of A (the bracketed [%1])
  //   desc_b : 64-bit descriptor for B
  //   tmem_c : 32-bit address operand of the accumulator (the bracketed [%0])
  //   scaleC : any non-zero value sets the predicate operand p
  //   idescE : only the upper 32 bits are passed to the instruction
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& tmem_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // The trailing literal 0 is presumably the optional zero-column-mask descriptor —
    // confirm against the PTX ISA.
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
  }
};
// MMA_Traits for SM100_MMA_F16BF16_WS_TS_NOELECT: A is a tmem fragment, B is a
// shared-memory descriptor (UMMA::smem_desc), and the accumulator is a tmem fragment.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
                                                  M, N,
                                                  a_major, b_major,
                                                  a_neg, b_neg>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types");

  using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;

  // One instruction spans 256 bits along K; expressed in elements of A
  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_1>;
  using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
                         Stride<_0,Stride<    _1,Int<M>>>>;
  using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
                         Stride<_0,Stride<    _1,Int<N>>>>;
  using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
                         Stride<_0,Stride<    _1,Int<M>>>>;

  // 1: accumulate onto C, 0: ignore C (clears the accumulators)
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Compile-time instruction descriptor for this atom
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // Pulls the raw operands out of the fragments and issues a single MMA.
  // Only D's address is used for the accumulator; C is checked for its memory space only.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using MmaOp = SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
                                                  M, N,
                                                  a_major, b_major,
                                                  a_neg, b_neg>;

    uint32_t const a_addr        = raw_pointer_cast(A.data());
    uint64_t const b_desc        = B[0];
    uint32_t const acc_addr      = raw_pointer_cast(D.data());
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
    uint32_t const scale_c       = static_cast<uint32_t>(traits.accumulate_);

    MmaOp::fma(a_addr, b_desc, acc_addr, scale_c, runtime_idesc);
  }
};
// tcgen05.mma (cta_group::2, kind::f16) with A given as a 32-bit bracketed address operand
// and B as a 64-bit descriptor. A must be K-major (enforced by the static_assert below).
// The NOELECT suffix presumably means the instruction is issued without an elect_one_sync()
// wrapper, as stated for the SS variant later in this file — TODO confirm.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");
using DRegisters = void;
using ARegisters = uint32_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
// Issue one MMA.
//   tmem_a : 32-bit address operand of A (the bracketed [%1])
//   desc_b : 64-bit descriptor for B
//   tmem_c : 32-bit address operand of the accumulator (the bracketed [%0])
//   scaleC : any non-zero value sets the predicate operand p
//   idescE : only the upper 32 bits are passed to the instruction
CUTE_HOST_DEVICE static void
fma(uint32_t const& tmem_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
// Eight zero words fill the braced operand list {%5..%12}; presumably this is the
// disable-output-lane mask with nothing disabled — confirm against the PTX ISA.
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
// Reached only when the build does not define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// MMA_Traits for SM100_MMA_F16BF16_2x1SM_TS_NOELECT: A is a tmem fragment, B is a
// shared-memory descriptor (UMMA::smem_desc), and the accumulator is a tmem fragment.
// ThrID is Layout<_2> and the operand layouts split M (for A and C) and N (for B) in two.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
          UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, c_sat>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");

  using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // Size of the instruction's K extent is always 256 bits; convert to units of elements
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_2>;
  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Accumulate or overwrite C.   1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Compile-time instruction descriptor for this atom
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();

  // Pulls the raw operands out of the fragments and issues a single MMA.
  // Only D's address is used for the accumulator; C is checked for its memory space only.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    // A lives in tmem for the TS variant, so the diagnostic says so
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    // fma() takes A as uint32_t const&; holding it in 32 bits avoids an implicit
    // narrowing temporary at the call and matches the WS_TS traits in this file.
    uint32_t tmem_a = raw_pointer_cast(A.data());
    uint64_t desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    uint64_t idesc  = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                       M, N,
                                       a_major, b_major,
                                       a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
  }
};
// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync()
// tcgen05.mma (cta_group::2, kind::f16) with both A and B supplied as 64-bit descriptors.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
// Issue one MMA.
//   desc_a, desc_b : 64-bit descriptors for A and B
//   tmem_c         : 32-bit address operand of the accumulator (the bracketed [%0])
//   scaleC         : any non-zero value sets the predicate operand p
//   idescE         : only the upper 32 bits are passed to the instruction
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
// Eight zero words fill the braced operand list {%5..%12}; presumably this is the
// disable-output-lane mask with nothing disabled — confirm against the PTX ISA.
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
// Reached only when the build does not define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// template <class a_type, class b_type, class c_type,
// int M, int N, UMMA::Major a_major, UMMA::Major b_major,
// UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
// struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> : MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> {};
// MMA_Traits for SM100_MMA_F16BF16_2x1SM_SS_NOELECT: A and B are shared-memory descriptors
// (UMMA::smem_desc), the accumulator uses the UMMA::tmem_frg_2sm fragment type.
// ThrID is Layout<_2> and the operand layouts split M (for A and C) and N (for B) in two.
template <class a_type, class b_type, class c_type,
          int M, int N,
          UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                     M, N, a_major, b_major,
                                                     a_neg, b_neg>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");

  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // One instruction spans 256 bits along K; expressed in elements of A
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_2>;
  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Compile-time instruction descriptor for this atom
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // 1: accumulate onto C, 0: ignore C (clears the accumulators)
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Pulls the raw operands out of the fragments and issues a single MMA.
  // Only D's address is used for the accumulator; C is checked for its memory space only.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using MmaOp = SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg>;

    uint64_t const a_desc        = A[0];
    uint64_t const b_desc        = B[0];
    uint32_t const acc_addr      = raw_pointer_cast(D.data());
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
    uint32_t const scale_c       = static_cast<uint32_t>(traits.accumulate_);

    MmaOp::fma(a_desc, b_desc, acc_addr, scale_c, runtime_idesc);
  }
};
} // namespace cute
...@@ -9,6 +9,7 @@ from torch.utils.cpp_extension import ( ...@@ -9,6 +9,7 @@ from torch.utils.cpp_extension import (
BuildExtension, BuildExtension,
CUDAExtension, CUDAExtension,
IS_WINDOWS, IS_WINDOWS,
CUDA_HOME
) )
...@@ -22,8 +23,21 @@ def get_features_args(): ...@@ -22,8 +23,21 @@ def get_features_args():
return features_args return features_args
def get_arch_flags(): def get_arch_flags():
# Check NVCC Version
# NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support"
nvcc_version = subprocess.check_output(
[os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT
).decode('utf-8')
nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip()
major, minor = map(int, nvcc_version_number.split('.'))
print(f'Compiling using NVCC {major}.{minor}')
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
if major < 12 or (major == 12 and minor <= 8):
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
arch_flags = [] arch_flags = []
if not DISABLE_SM100: if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
...@@ -55,8 +69,10 @@ ext_modules.append( ...@@ -55,8 +69,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/splitkv_mla.cu", "csrc/sm90/decode/dense/splitkv_mla.cu",
"csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm90/prefill/sparse/fwd.cu", "csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm100/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
"csrc/sm100/prefill/sparse/fwd.cu",
], ],
extra_compile_args={ extra_compile_args={
"cxx": cxx_args + get_features_args(), "cxx": cxx_args + get_features_args(),
......
...@@ -320,6 +320,11 @@ def main(torch_dtype): ...@@ -320,6 +320,11 @@ def main(torch_dtype):
testcases = correctness_cases + corner_cases + performance_cases testcases = correctness_cases + corner_cases + performance_cases
# Prune out unsupported cases
cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases: for testcase in testcases:
test_flash_mla(testcase) test_flash_mla(testcase)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment