Commit 5ba7dc40 authored by danyao12's avatar danyao12
Browse files

add grouped pt1 qloop

parent 30716b7c
...@@ -8,6 +8,7 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe ...@@ -8,6 +8,7 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp) add_example_executable(example_grouped_multihead_attention_forward grouped_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp) add_example_executable(example_batched_multihead_attention_forward batched_multihead_attention_forward.cpp)
add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp) add_example_executable(example_grouped_multihead_attention_backward grouped_multihead_attention_backward.cpp)
add_example_executable(example_grouped_multihead_attention_backward_v2 grouped_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp) add_example_executable(example_batched_multihead_attention_backward batched_multihead_attention_backward.cpp)
add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp) add_example_executable(example_batched_multihead_attention_backward_v2 batched_multihead_attention_backward_v2.cpp)
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp) add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
......
...@@ -63,9 +63,9 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -63,9 +63,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using InputDataType = BF16; using InputDataType = F16;
using OutputDataType = F32; using OutputDataType = F16;
using GemmDataType = BF16; using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -80,7 +80,7 @@ static constexpr ck::index_t NumDimK = 1; ...@@ -80,7 +80,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4 // When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8 // When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4; static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
...@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template. // If DIM <= 32 , ues prototype1 1st template.
...@@ -511,7 +511,7 @@ int run(int argc, char* argv[]) ...@@ -511,7 +511,7 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
float p_drop = 0.2; float p_drop = 0.0;
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
......
...@@ -62,9 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -62,9 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using InputDataType = BF16; using InputDataType = F16;
using OutputDataType = F32; using OutputDataType = F16;
using GemmDataType = BF16; using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -79,7 +79,7 @@ static constexpr ck::index_t NumDimK = 1; ...@@ -79,7 +79,7 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4 // When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8 // When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4; static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
...@@ -94,7 +94,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -94,7 +94,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template. // If DIM <= 32 , ues prototype1 1st template.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment