Commit 00cb7e41 authored by danyao12's avatar danyao12
Browse files

modify comment

parent c07c2b55
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -78,7 +78,7 @@ using GemmDataType = F16; ...@@ -78,7 +78,7 @@ using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = INT32; // INT32 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
...@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1; ...@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4 // When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8 // When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4; static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
...@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template. // If DIM <= 32, use prototype1 1st template.
......
...@@ -879,7 +879,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -879,7 +879,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl m2, // MPerXdl
n2, // NGroupNum n2, // NGroupNum
n3, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // RegisterNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
...@@ -889,7 +889,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -889,7 +889,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl m2, // MPerXdl
n2, // NGroupNum n2, // NGroupNum
n3, // NInputNum n3, // NInputNum
n4, // registerNum n4, // RegisterNum
I1)); // I1 I1)); // I1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
...@@ -902,7 +902,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -902,7 +902,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl m2, // MPerXdl
n2, // NGroupNum n2, // NGroupNum
n3, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // RegisterNum
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
...@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl m2, // MPerXdl
n2, // NGroupNum n2, // NGroupNum
n3, // NInputNum n3, // NInputNum
n4>, // registerNum n4>, // RegisterNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim 7, // DstVectorDim
1, // DstScalarPerVector 1, // DstScalarPerVector
...@@ -982,12 +982,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -982,12 +982,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{ true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // mrepeat make_multi_index(0, // MRepeat
0, // nrepeat 0, // NRepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl wave_m_n_id[I1], // MPerXdl
0, // group 0, // NGroupIndex
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I0], // NInputIndex
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
...@@ -1003,8 +1003,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1003,8 +1003,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1, 1,
1, 1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4, true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat make_multi_index(0, // MRepeat
0, // nrepeat 0, // NRepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1] / ZN4, wave_m_n_id[I1] / ZN4,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment