Commit 422a69b2 authored by letaoqin's avatar letaoqin
Browse files

add static check for vector load

parent 6c971dc8
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -88,6 +88,9 @@ template <typename InputDataType, ...@@ -88,6 +88,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -96,6 +96,10 @@ template <typename InputDataType, ...@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0); static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -87,6 +87,9 @@ template <typename InputDataType, ...@@ -87,6 +87,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -95,6 +95,10 @@ template <typename InputDataType, ...@@ -95,6 +95,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0); static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -97,6 +97,10 @@ template <typename FloatAB, ...@@ -97,6 +97,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 || static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
......
...@@ -88,6 +88,10 @@ template <typename FloatAB, ...@@ -88,6 +88,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 || static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment