"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "a5232a7f2ed96b6157b8f9d728ab3637869daa1c"
Commit 00cb7e41 authored by danyao12's avatar danyao12
Browse files

modify comment

parent c07c2b55
......@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -78,7 +78,7 @@ using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = INT32; // INT32
using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
......@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
......@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
......
......@@ -879,7 +879,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
n4)); // RegisterNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
......@@ -889,7 +889,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4, // registerNum
n4, // RegisterNum
I1)); // I1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
......@@ -902,7 +902,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
n4)); // RegisterNum
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
......@@ -974,7 +974,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>, // registerNum
n4>, // RegisterNum
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
......@@ -982,12 +982,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1, // DstScalarStrideInVector
true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
0, // NGroupIndex
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
......@@ -1003,8 +1003,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1] / ZN4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment