Commit 422a69b2 authored by letaoqin's avatar letaoqin
Browse files

add static check for vector load

parent 6c971dc8
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
......@@ -88,6 +88,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -87,6 +87,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -95,6 +95,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -97,6 +97,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
......
......@@ -88,6 +88,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment